In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [2]:
from app.classificator import ResClassifier, Classificator, VitClassifier

In [3]:
from torchvision.models import resnet34, ResNet, resnext101_64x4d

In [4]:

import os
from PIL import Image
from torchvision import transforms

class CustomImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = sorted(os.listdir(root_dir))
        self.image_paths = []
        self.labels = []
        
        # Gather image paths and corresponding labels
        for idx, cls in enumerate(self.classes):
            cls_folder = os.path.join(root_dir, cls)
            if os.path.isdir(cls_folder):
                for img_name in os.listdir(cls_folder):
                    img_path = os.path.join(cls_folder, img_name)
                    self.image_paths.append(img_path)
                    self.labels.append(idx)
    
    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("RGB")
        
        if self.transform:
            image = self.transform(image)
        
        return image, label



In [5]:
dataset = CustomImageDataset(root_dir="/home/user1/hack/train_data_rkn/dataset", transform=transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
]))
# train_dataloader = DataLoader(dataset, batch_size=128, shuffle=True,num_workers=4)

In [6]:
import numpy as np
np.unique(dataset.labels, return_counts=True), dataset.labels.__len__()

((array([  1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
          14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,
          27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,
          40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,
          53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,  65,
          66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  78,
          79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,
          92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103, 104,
         105]),
  array([200, 109, 200, 200, 127, 200, 200, 180, 200, 203, 164, 200, 197,
         200, 194, 107, 195, 197, 200, 111, 200, 222, 200,  53, 200,   7,
         131, 200,  99, 200, 199, 196, 200,   5, 200,  62, 200, 200, 200,
         200, 272, 199, 200, 200, 199, 199, 200, 200, 200, 200, 200, 200,
          87, 200, 200, 200, 195, 200, 200, 375, 200, 200, 200, 123, 370,
         181, 200, 200

In [7]:
classifier = ResClassifier(resnext101_64x4d,num_classes=len(set(dataset.labels))+1)
# classifier = VitClassifier(num_classes=106)

In [8]:
classifier.tune(dataset, device="cuda", epochs=15, batch_size=128, dl_num_workers=4, lr=1e-6)

Epoch 1/15: 100%|██████████| 121/121 [02:53<00:00,  1.44s/it, loss=4.54]


Epoch 1/15 - Mean Average Precision: 0.0419


Epoch 2/15: 100%|██████████| 121/121 [02:53<00:00,  1.43s/it, loss=4.11]


Epoch 2/15 - Mean Average Precision: 0.0675


Epoch 3/15: 100%|██████████| 121/121 [02:53<00:00,  1.43s/it, loss=3.88]


Epoch 3/15 - Mean Average Precision: 0.0919


Epoch 4/15: 100%|██████████| 121/121 [02:53<00:00,  1.43s/it, loss=3.65]


Epoch 4/15 - Mean Average Precision: 0.1018


Epoch 5/15: 100%|██████████| 121/121 [02:53<00:00,  1.43s/it, loss=3.46]


Epoch 5/15 - Mean Average Precision: 0.1024


Epoch 6/15: 100%|██████████| 121/121 [02:53<00:00,  1.44s/it, loss=3.3] 


Epoch 6/15 - Mean Average Precision: 0.1461


Epoch 7/15: 100%|██████████| 121/121 [02:53<00:00,  1.43s/it, loss=3.12]


Epoch 7/15 - Mean Average Precision: 0.1506


Epoch 8/15: 100%|██████████| 121/121 [02:53<00:00,  1.43s/it, loss=2.97]


Epoch 8/15 - Mean Average Precision: 0.1729


Epoch 9/15: 100%|██████████| 121/121 [02:53<00:00,  1.43s/it, loss=2.78]


Epoch 9/15 - Mean Average Precision: 0.1633


Epoch 10/15: 100%|██████████| 121/121 [02:53<00:00,  1.43s/it, loss=2.6] 


Epoch 10/15 - Mean Average Precision: 0.1381


Epoch 11/15: 100%|██████████| 121/121 [02:53<00:00,  1.44s/it, loss=2.47]


Epoch 11/15 - Mean Average Precision: 0.1637


Epoch 12/15: 100%|██████████| 121/121 [02:53<00:00,  1.43s/it, loss=2.26]


Epoch 12/15 - Mean Average Precision: 0.1937


Epoch 13/15: 100%|██████████| 121/121 [02:53<00:00,  1.43s/it, loss=2.07]


Epoch 13/15 - Mean Average Precision: 0.1700


Epoch 14/15: 100%|██████████| 121/121 [02:53<00:00,  1.43s/it, loss=1.81]


Epoch 14/15 - Mean Average Precision: 0.1641


Epoch 15/15: 100%|██████████| 121/121 [02:53<00:00,  1.44s/it, loss=1.58]


Epoch 15/15 - Mean Average Precision: 0.1813
Fine-tuning complete.


In [12]:
classifier.head, classifier.backbone

(Linear(in_features=2048, out_features=106, bias=True),
 ResNet(
   (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
   (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (relu): ReLU(inplace=True)
   (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
   (layer1): Sequential(
     (0): Bottleneck(
       (conv1): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
       (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
       (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (conv3): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
       (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (relu): ReLU(inplace=True)
     

In [13]:
torch.save(classifier.head, "resclassifier_head_v1_loss1.58_map0.1813.pt")

In [14]:
torch.save(classifier.backbone, "resclassifier_backbone_v1_loss1.58_map0.1813.pt")