In [None]:
# Sanity check: list the visible NVIDIA GPUs (the model is trained on CUDA below).
!nvidia-smi

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
# Record library versions in the notebook output for reproducibility.
print(f"PyTorch Version:  {torch.__version__}")
print(f"Torchvision Version:  {torchvision.__version__}")
from PIL import Image
import pandas as pd
import random
from tqdm import tqdm
import timm

In [None]:
import sys

# Make the parent directory importable (it provides `wmdetection`, imported by the ConvNeXt cell).
# Guarded so re-running this cell does not keep appending duplicate entries to sys.path.
if '../' not in sys.path:
    sys.path.append('../')

### Set input image size

In [None]:
# Side length (pixels) that every image is resized to in `preprocess`.
input_size = 256

# Seed every RNG this notebook draws from: python `random` (RandomRotation),
# numpy, and torch (weight init, DataLoader shuffling, torchvision augmentations).
# GPU kernels may still be non-deterministic, so runs are only approximately reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

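### Create model

Run **only one** of the three model cells below (ResNeXt, ConvNeXt or EfficientNet). Each of them assigns `model_ft`, so with *Run All* the last one (EfficientNet) is the model that gets trained, even though the weights file saved at the end is named `convnext-t_…`.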

#### ResNeXt

In [None]:
# Pretrained torchvision ResNeXt-50 (32x4d); its final `fc` layer is replaced by a
# freshly initialised 2-class linear layer.
model_ft = models.resnext50_32x4d(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(in_features=num_ftrs, out_features=2)

#### ConvNeXt

In [None]:
from wmdetection.models.convnext import convnext_tiny, convnext_small

# ConvNeXt-Tiny built with the 21841-class head (presumably the ImageNet-22k checkpoint, per
# `in_22k=True`), then given a new MLP head: 768 -> 512 -> 256 -> 2 logits with GELU in between.
model_ft = convnext_tiny(pretrained=True, in_22k=True, num_classes=21841)

model_ft.head = nn.Sequential(
    *[
        module
        for fan_in, fan_out in ((768, 512), (512, 256))
        for module in (nn.Linear(in_features=fan_in, out_features=fan_out), nn.GELU())
    ],
    nn.Linear(in_features=256, out_features=2),
)

#### EfficientNet

In [None]:
# timm EfficientNet-B3a created with a 2-way classifier, which is then swapped for a deeper MLP head.
model_ft = timm.create_model('efficientnet_b3a', pretrained=True, num_classes=2)

model_ft.classifier = nn.Sequential(
    # 1536 -> 625, with dropout after the first activation
    nn.Linear(1536, 625), nn.ReLU(), nn.Dropout(0.3),
    # 625 -> 256
    nn.Linear(625, 256), nn.ReLU(),
    # 256 -> 2 class logits
    nn.Linear(256, 2),
)

### Prepare for training

In [None]:
# Move all parameters and buffers to the current CUDA device.
model_ft = model_ft.to(torch.device('cuda'))

In [None]:
class RandomRotation:
    """With probability ``p``, rotate the input by an angle picked from a fixed list.

    Args:
        angles: non-empty sequence of angles in degrees; one is chosen uniformly
            with ``random.choice`` and passed to ``transforms.functional.rotate``.
        p: probability in [0, 1] of rotating a given sample.

    Raises:
        ValueError: if ``angles`` is empty or ``p`` is outside [0, 1]. Checked here so a
            bad configuration fails immediately rather than at random inside a DataLoader worker.
    """

    def __init__(self, angles, p):
        angles = list(angles)
        if not angles:
            raise ValueError("angles must contain at least one angle")
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"p must be in [0, 1], got {p!r}")
        self.p = p
        self.angles = angles

    def __call__(self, x):
        # Leave the sample untouched with probability 1 - p.
        if random.random() >= self.p:
            return x
        angle = random.choice(self.angles)
        return transforms.functional.rotate(x, angle)

    def __repr__(self):
        # Shown when the enclosing transforms.Compose is printed.
        return f"{type(self).__name__}(angles={self.angles}, p={self.p})"

# Standard ImageNet per-channel mean / std, shared by both pipelines.
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]

# Per-split preprocessing: resize to a fixed square, (train only) flip + occasional 90-degree
# turn, then tensor conversion and normalisation.
preprocess = dict(
    train=transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # Runs after ToTensor, so the rotation is applied to a tensor.
        RandomRotation([90, -90], 0.2),
        transforms.Normalize(_imagenet_mean, _imagenet_std),
    ]),
    val=transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        transforms.Normalize(_imagenet_mean, _imagenet_std),
    ]),
)

## Prepare dataset

In [None]:
# Single root for every dataset file: the CSVs are read from it, and each CSV `path`
# entry is joined onto the same root.
# NOTE(review): the original joined image paths onto 'dataset' while reading the CSVs from
# '../dataset'. Relative to the same working directory those are two different places.
# This assumes `path` is relative to the dataset root; confirm against the CSV contents.
DATASET_DIR = '../dataset'

df_train = pd.read_csv(os.path.join(DATASET_DIR, 'train_data_v1-1.csv'))
df_val = pd.read_csv(os.path.join(DATASET_DIR, 'val_data_v1-1.csv'))

df_train['path'] = df_train['path'].map(lambda p: os.path.join(DATASET_DIR, p))
df_val['path'] = df_val['path'].map(lambda p: os.path.join(DATASET_DIR, p))

In [None]:
# Class balance of the training split.
df_train.label.value_counts(dropna=True)

In [None]:
# Class balance of the validation split.
df_val.label.value_counts(dropna=True)

In [None]:
from io import BytesIO
from PIL import ImageFile
# Let PIL load truncated image files instead of raising an error on them.
setattr(ImageFile, "LOAD_TRUNCATED_IMAGES", True)

class WatermarkDataset(torch.utils.data.Dataset):
    """Map-style dataset over a DataFrame of image paths and labels.

    Args:
        df: DataFrame with a ``path`` column (image file path) and a ``label`` column.
            A copy with a fresh 0..len-1 index is stored, so samples are addressed by position.
        transform: callable applied to the RGB-converted ``PIL.Image``; its result is
            the sample returned by ``__getitem__``.
    """

    def __init__(self, df, transform):
        self.df = df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(transform(image), label)`` for row ``idx``."""
        # Materialise the row once instead of calling .loc twice.
        row = self.df.loc[idx]
        # Context manager closes the file handle deterministically instead of
        # leaving it to garbage collection in every DataLoader worker.
        with Image.open(row.path) as img:
            tensor = self.transform(img.convert('RGB'))
        return tensor, row.label

In [None]:
train_ds = WatermarkDataset(df=df_train, transform=preprocess['train'])

In [None]:
val_ds = WatermarkDataset(df=df_val, transform=preprocess['val'])

In [None]:
# Split name -> dataset, consumed when building the DataLoaders.
# NOTE: this rebinds `datasets`, shadowing the `torchvision.datasets` module imported at the top.
datasets = dict(train=train_ds, val=val_ds)

## Train

In [None]:
from tqdm import tqdm
# Training device: CUDA GPU with index 0.
device = torch.device('cuda', 0)

def _model_device(model):
    """Return the device holding ``model``'s first parameter (CPU if it has none)."""
    first_param = next(model.parameters(), None)
    return torch.device('cpu') if first_param is None else first_param.device


def _run_epoch(model, loader, criterion, optimizer, is_train, device, scaler, use_amp):
    """Run one pass over ``loader`` in training or evaluation mode.

    Returns:
        (mean loss over ``loader.dataset`` as a float,
         accuracy over ``loader.dataset`` as a 0-dim float64 tensor on ``device``).
    """
    model.train(is_train)  # equivalent to model.train() / model.eval()

    running_loss = 0.0
    running_corrects = 0

    for inputs, labels in tqdm(loader, desc='train' if is_train else 'val'):
        inputs = inputs.to(device)
        labels = labels.to(device)

        if is_train:
            optimizer.zero_grad()

        # Autograd history is only recorded in the training phase.
        with torch.set_grad_enabled(is_train):
            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            _, preds = torch.max(outputs, 1)

            if is_train:
                # float16 autocast needs loss scaling; without it small gradients can
                # underflow to zero. scaler.step() unscales the gradients before the
                # optimizer update and skips the step if they contain inf/NaN.
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels)

    n_samples = len(loader.dataset)
    return running_loss / n_samples, running_corrects.double() / n_samples


def train_model(model, dataloaders, criterion, optimizer, num_epochs=80, restore_best=False):
    """Train ``model``, evaluating on the validation loader after every epoch.

    Args:
        model: network to train; batches are moved to the device of its first parameter.
            Mixed precision (autocast + gradient scaling) is used when that device is CUDA.
        dataloaders: dict with 'train' and 'val' DataLoaders yielding (inputs, labels).
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of train + val epochs.
        restore_best: if True, reload the weights from the epoch with the highest
            validation accuracy before returning. The default (False) returns the
            final-epoch weights, as before.

    Returns:
        (model, train_acc_history, val_acc_history). Each history holds one 0-dim
        tensor accuracy per epoch.
    """
    since = time.time()
    device = _model_device(model)
    use_amp = device.type == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    train_acc_history = []
    val_acc_history = []
    best_acc = 0.0
    # Only snapshot weights when they will actually be restored (deepcopy costs a full model copy).
    best_model_wts = None

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ['train', 'val']:
            epoch_loss, epoch_acc = _run_epoch(
                model, dataloaders[phase], criterion, optimizer,
                is_train=(phase == 'train'), device=device, scaler=scaler, use_amp=use_amp,
            )
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            if phase == 'train':
                train_acc_history.append(epoch_acc)
            else:
                val_acc_history.append(epoch_acc)
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    if restore_best:
                        best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    # '{:.4f}' = 4 decimals (the previous '{:4f}' was a minimum width, not a precision).
    print('Best val Acc: {:.4f}'.format(best_acc))

    # None unless restore_best=True and some epoch beat 0 validation accuracy.
    if best_model_wts is not None:
        model.load_state_dict(best_model_wts)

    return model, train_acc_history, val_acc_history

In [None]:
criterion = torch.nn.CrossEntropyLoss()
# Small learning rate (2e-6) for fine-tuning the pretrained network.
optimizer = optim.AdamW(params=model_ft.parameters(), lr=0.2e-5)

BATCH_SIZE = 64
# Never request more loader worker processes than there are CPUs.
NUM_WORKERS = min(12, os.cpu_count() or 1)

# Only the training split is shuffled. Validation metrics do not depend on order,
# so iterating it in a fixed order just makes evaluation passes repeatable.
dataloaders_dict = {
    phase: torch.utils.data.DataLoader(
        datasets[phase],
        batch_size=BATCH_SIZE,
        shuffle=(phase == 'train'),
        num_workers=NUM_WORKERS,
        # Page-locked host memory speeds up host-to-GPU batch copies.
        pin_memory=torch.cuda.is_available(),
    )
    for phase in ['train', 'val']
}

In [None]:
import warnings

# Silence warnings only while training runs, instead of disabling them
# for the rest of the kernel session.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    model_ft, train_acc_history, val_acc_history = train_model(
        model_ft, dataloaders_dict, criterion, optimizer, num_epochs=3
    )

### Plot accuracy history

In [None]:
# Per-epoch accuracy curves. float() accepts both the 0-dim tensors produced by
# train_model (on CPU or CUDA) and plain Python numbers.
fig, ax = plt.subplots()
ax.plot([float(acc) for acc in train_acc_history], label='train')
ax.plot([float(acc) for acc in val_acc_history], label='valid')
ax.set_title('Model accuracy')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.legend(loc='upper left')
plt.show()

### Save model

In [None]:
WEIGHTS_DIR = '../weights'
os.makedirs(WEIGHTS_DIR, exist_ok=True)

# Input size and epoch count come from this run instead of being typed into the name,
# so the filename cannot drift from the training cell.
# NOTE(review): the "convnext-t_3layer-head" prefix is still hard-coded and is only accurate
# if the ConvNeXt cell was the last model cell executed. Rename when training ResNeXt/EfficientNet.
weights_path = os.path.join(
    WEIGHTS_DIR,
    f"convnext-t_3layer-head_inp{input_size}_datasetv1-1_{len(train_acc_history)}epochs_v3.pth",
)
torch.save(model_ft.state_dict(), weights_path)
weights_path