In [49]:
import torch
import torch.nn
from torchvision.models import convnext_large, convnext_small
from torch.optim import SGD, Adam
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, Subset
from PIL import Image
import os

In [50]:
transform = transforms.Compose([
   transforms.Resize([232], interpolation=transforms.InterpolationMode.BILINEAR),
                                transforms.CenterCrop([224,224]),
                                transforms.ToTensor(), # Rescaled to 0-1 aswell
                                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

target_transform = transforms.Compose([
                                transforms.Resize([232], interpolation=transforms.InterpolationMode.BILINEAR),
                                transforms.CenterCrop([224,224]),
                                transforms.ToTensor(),
])

In [51]:
class CustomDataset(Dataset):
    def __init__(self, path, transform=None, target_transform=None):
        self.transform= transform
        self.target_transform = target_transform
        self.path = path
        self.labels = []
        self.images = []
        self.count = 0
        for i, direc in enumerate(os.listdir(self.path)[1:]): # [1:] because MacOS has .DS_Store at 0th index
            for j, file_name in enumerate(os.listdir(os.path.join(self.path,direc))):
                self.labels.append(i)
                image_path = os.path.join(self.path, direc ,file_name)
                image = Image.open(image_path)
                image = target_transform(image)
                self.images.append(image)
        
        # Early loading
        
        
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        # Lazy loading
        
#         directory_idx = idx//100
#         file_idx = idx%100
#         label = directory_idx

        return self.images[idx], self.labels[idx]
        

In [52]:
dataset = CustomDataset('tiny_FR', target_transform=target_transform)

In [53]:
torch.manual_seed(10)
index_train_test = torch.randperm(500)
index_train_val = torch.randperm(400)

In [54]:
data = Subset(dataset, index_train_test[:400])
data_test = Subset(dataset, index_train_test[400:])
data

<torch.utils.data.dataset.Subset at 0x2d45fabb0>

In [56]:
data_train = Subset(data, index_train_val[:300])
data_valid = Subset(data, index_train_val[300:])
data_train

<torch.utils.data.dataset.Subset at 0x2d4603d60>

In [64]:
torch.manual_seed(1)
data = DataLoader(data_train, 32, shuffle=True)
data_vl = DataLoader(data_valid, 16, shuffle=True)

In [66]:
next(iter(data))[0].shape, next(iter(data_vl))[0].shape,

(torch.Size([32, 3, 224, 224]), torch.Size([16, 3, 224, 224]))

In [68]:
model = convnext_small(weights='IMAGENET1K_V1')

In [69]:
model

ConvNeXt(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
    )
    (1): Sequential(
      (0): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=96, out_features=384, bias=True)
          (4): GELU(approximate='none')
          (5): Linear(in_features=384, out_features=96, bias=True)
          (6): Permute()
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
      (1): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=

In [70]:
for param in model.parameters():
    param.requires_grad = False

In [71]:
num_ftrs = model.classifier[2].in_features
model.classifier[2] = torch.nn.Linear(num_ftrs, 5)

In [72]:
model

ConvNeXt(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
    )
    (1): Sequential(
      (0): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=96, out_features=384, bias=True)
          (4): GELU(approximate='none')
          (5): Linear(in_features=384, out_features=96, bias=True)
          (6): Permute()
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
      (1): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=

In [73]:
model = model.to('mps')

In [74]:
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.classifier[2].parameters(), lr=0.01, momentum=0.9) #0.01 is good

In [75]:
def train(epoch, data, data_vl, PATH="model.pt"):
    train_loss = [0] * epoch
    train_accuracy = [0] * epoch
    
    valid_loss = [0] * epoch
    valid_accuracy = [0] * epoch
    model.train()
    
    low_val_loss = 1000
    for i in range(epoch):
        for x_batch, y_batch in data:
            #print(x_batch.shape)
            x_batch = x_batch.to('mps')
            y_batch = y_batch.to('mps')
            pred = model(x_batch)
            loss = loss_fn(pred, y_batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            
            train_loss[i] += loss.item() * x_batch.size(0)
            correct = torch.sum(torch.argmax(torch.softmax(pred,axis=1), axis=1) == y_batch)
            train_accuracy[i] += correct.to(torch.float32).item()
        train_loss[i] /= len(data.dataset)
        train_accuracy[i] /= len(data.dataset)
        
        with torch.no_grad():
            for x_batch, y_batch in data_vl:
                x_batch = x_batch.to('mps')
                y_batch = y_batch.to('mps')
                pred = model(x_batch)
                loss = loss_fn(pred, y_batch)
                valid_loss[i] += loss.item() * x_batch.size(0)
                correct = torch.sum(torch.argmax(torch.softmax(pred,axis=1), axis=1) == y_batch)
                valid_accuracy[i] += correct.item()
            valid_loss[i] /= len(data_vl.dataset)
            valid_accuracy[i] /= len(data_vl.dataset)
        
        if valid_loss[i] < low_val_loss:
            torch.save({
            'epoch': i +1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': valid_loss[i],
            }, PATH)
            low_val_loss = valid_loss[i]
        print(f'Epoch {i+1} accuracy: {train_accuracy[i]:.4f} val_accuracy:{valid_accuracy[i]:.4f}')
        print(f'Epoch {i+1} loss: {train_loss[i]:.4f} val_loss:{valid_loss[i]:.4f}')
        print()
            

In [76]:
train(30, data, data_vl)

Epoch 1 accuracy: 0.4867 val_accuracy:0.6800
Epoch 1 loss: 1.3818 val_loss:0.9353

Epoch 2 accuracy: 0.7167 val_accuracy:0.8700
Epoch 2 loss: 0.7435 val_loss:0.4928

Epoch 3 accuracy: 0.8467 val_accuracy:0.9000
Epoch 3 loss: 0.4870 val_loss:0.3956

Epoch 4 accuracy: 0.8867 val_accuracy:0.8900
Epoch 4 loss: 0.3734 val_loss:0.3919

Epoch 5 accuracy: 0.9300 val_accuracy:0.9500
Epoch 5 loss: 0.3117 val_loss:0.3271

Epoch 6 accuracy: 0.9367 val_accuracy:0.9400
Epoch 6 loss: 0.2730 val_loss:0.3193

Epoch 7 accuracy: 0.9500 val_accuracy:0.9300
Epoch 7 loss: 0.2636 val_loss:0.3243

Epoch 8 accuracy: 0.9600 val_accuracy:0.8900
Epoch 8 loss: 0.1981 val_loss:0.2968

Epoch 9 accuracy: 0.9567 val_accuracy:0.9200
Epoch 9 loss: 0.2131 val_loss:0.3131

Epoch 10 accuracy: 0.9633 val_accuracy:0.9300
Epoch 10 loss: 0.1839 val_loss:0.2923

Epoch 11 accuracy: 0.9700 val_accuracy:0.9200
Epoch 11 loss: 0.1697 val_loss:0.2914

Epoch 12 accuracy: 0.9833 val_accuracy:0.9200
Epoch 12 loss: 0.1555 val_loss:0.3004

In [77]:
checkpoint = torch.load("model.pt")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
epoch, loss

(20, 0.19982839554548262)

## Evaluation

In [78]:
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix

In [82]:
len(data_test)

100

In [83]:
data_tst = DataLoader(data_test, 100, False)
test_data = next(iter(data_tst))[0]
test_data.shape

torch.Size([100, 3, 224, 224])

In [104]:
y_pred = model(test_data.to('mps'))
y_pred = torch.argmax(torch.softmax(y_pred, axis=1),axis=1)
y_pred.shape

torch.Size([100])

In [105]:
y_test = next(iter(data_tst))[1]
y_test.shape

torch.Size([100])

In [106]:
y_pred = torch.detach(y_pred.to('cpu')).numpy()
y_test = torch.detach(y_test.to('cpu')).numpy()

In [110]:
p = precision_score(y_test, y_pred, average = 'weighted')
print(f"Precision: {p}")

r = recall_score(y_test, y_pred , average = 'weighted')
print(f"Recall: {r}")

f1 = f1_score( y_test, y_pred, average = 'weighted')
print(f"F1-score: {f1}")


Precision: 0.8439422421142371
Recall: 0.83
F1-score: 0.8297688997993876
