In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import models, transforms, datasets
from sklearn.metrics import confusion_matrix, accuracy_score, roc_auc_score, precision_score, f1_score, recall_score
from torchvision.models import ResNet50_Weights
from tensorboardX import SummaryWriter
import matplotlib.pyplot as plt
import numpy as np
import itertools
from tqdm import tqdm
import torch.optim.lr_scheduler as lr_scheduler
import warnings
import pandas as pd
from early_stopping import EarlyStopping
import multiprocessing
import seaborn as sns
import matplotlib.pyplot as plt
import datetime
import numbers
import os
import time
from torch.utils.data import random_split

# Suppress specific warnings
warnings.filterwarnings("ignore", message="Mean of empty slice")
warnings.filterwarnings("ignore", message="invalid value encountered in scalar divide")
warnings.simplefilter(action='ignore', category=FutureWarning)


  from .autonotebook import tqdm as notebook_tqdm


In [20]:
dino = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg_lc')

Using cache found in C:\Users\Csabi/.cache\torch\hub\facebookresearch_dinov2_main


In [5]:
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

ds = datasets.ImageFolder('../data/combined_data', transform=train_transform)

generator = torch.Generator().manual_seed(42)
splits = random_split(ds, [0.8, 0.1, 0.1], generator=generator)

train_loader = DataLoader(splits[0], batch_size=1, shuffle=True)


In [8]:
for idx, (images, labels) in enumerate(train_loader):
    output = dino(images)
    print(output.shape)
    print(output)
    break

torch.Size([1, 1000])
tensor([[-2.1194e+00,  1.1930e+00,  2.8942e+00,  3.9242e+00, -1.1909e+00,
          4.4121e+00,  3.5639e+00, -1.9800e+00, -2.9205e+00, -1.1568e+00,
         -2.4510e+00, -1.0676e+00, -1.4491e+00, -3.7137e+00, -3.0774e+00,
          9.4078e-02, -2.9535e+00, -2.7333e+00, -3.7692e-01, -3.0447e+00,
         -1.5742e+00, -2.3544e+00, -2.0268e+00, -2.8583e+00,  7.3290e-01,
         -1.6890e+00,  4.1905e-01,  3.3083e-01, -1.4263e+00,  2.3497e+00,
          1.5611e-02, -1.0656e+00,  1.2170e+00, -1.0723e+00,  4.6177e-01,
          1.1114e+00, -2.6614e+00, -3.6995e+00,  2.0145e-01,  4.6400e-01,
         -1.9924e+00, -1.1615e-02, -7.9526e+00, -2.8874e+00, -3.7541e+00,
         -1.4198e+00, -1.7141e+00, -1.2801e+00, -2.6901e+00, -1.9316e+00,
          7.5068e-01, -9.1190e-01,  2.8815e+00,  1.0715e+00,  2.4640e+00,
          1.5590e+00, -1.8430e+00, -3.5387e+00,  2.8950e+00,  3.1062e+00,
          4.4947e+00,  1.7417e+00,  1.1575e+00,  5.7030e-01,  9.4455e-01,
          1.8285

In [21]:
model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
#for param in model.parameters():
#    param.requires_grad = False
for name, param in model.named_parameters():
    if "fc" in name:
        param.requires_grad = False

num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(ds.classes))
model.fc.requires_grad = True

In [60]:
import copy


class DinoRes(nn.Module):
    def __init__(self, resnet, dino, num_classes, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.dino = copy.deepcopy(dino)
        dino_feat = dino.linear_head.in_features
        self.dino.linear_head = nn.Identity()
        self.resnet = copy.deepcopy(resnet)
        num_features = resnet.fc.in_features
        self.resnet.fc = nn.Identity()
        self.fc = nn.Linear(num_features+dino_feat, num_classes)

    def forward(self, x):
        dino_features = self.dino(x)
        out = torch.cat((self.resnet(x), dino_features), 1)
        out = self.fc(out)
        return out

In [61]:
dino

_LinearClassifierWrapper(
  (backbone): DinoVisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 384, kernel_size=(14, 14), stride=(14, 14))
      (norm): Identity()
    )
    (blocks): ModuleList(
      (0-11): 12 x NestedTensorBlock(
        (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (attn): MemEffAttention(
          (qkv): Linear(in_features=384, out_features=1152, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=384, out_features=384, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): LayerScale()
        (drop_path1): Identity()
        (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=1536, out_features=384, bias=True)
          (drop): Dropout(p=0.0, inplace=Fals

In [63]:
dres = DinoRes(model, dino, 4)

In [64]:
for idx, (images, labels) in enumerate(train_loader):
    output = dres(images)
    print(output.shape)
    print(output)
    break

torch.Size([1, 4])
tensor([[-1.2127, -0.3822,  0.0854, -0.3054]], grad_fn=<AddmmBackward0>)
