In [1]:
import timm
import torch
from tqdm import tqdm
from PIL import Image
from sklearn import svm
from sklearn.metrics import log_loss
Image.MAX_IMAGE_PIXELS = None

from src.data.visiomel_datamodule import VisiomelTrainDatamodule

In [2]:
# Build the train/val datamodule (fold_index=0 of k=5, split_seed=0) and
# grab the train and validation dataloaders used for feature extraction.
datamodule = VisiomelTrainDatamodule(
    # paths
    data_dir_train='/workspace/data/images_page_7/',
    data_dir_test=None,
    # fold selection
    k=5,
    fold_index=0,
    split_seed=0,
    # image size / resizing
    img_size=384,
    train_resize_type='resize',
    shrink_preview_scale=None,
    data_shrinked=False,
    # loading options
    batch_size=32,
    sampler=None,
    num_workers=0,
    num_workers_saturated=10,
    pin_memory=False,
    prefetch_factor=None,
    persistent_workers=False,
    enable_caching=False,
)
datamodule.setup()

train_dataloader = datamodule.train_dataloader()
# val_dataloader() returns a pair; only the first element is used below.
val_dataloader, _ = datamodule.val_dataloader()

In [3]:
# Pretrained Swin-V2 large backbone. num_classes=0 asks timm to build the
# model without a classification head; .eval() puts it in inference mode.
model = (
    timm.create_model(
        model_name='swinv2_large_window12to24_192to384_22kft1k',
        pretrained=True,
        num_classes=0,
    )
    .cuda()
    .eval()
)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [6]:
def extract_features(dataloader, feature_model=None):
    """Extract multi-scale, mean-pooled Swin features for every batch.

    Runs the backbone stem (``patch_embed``, optional ``absolute_pos_embed``,
    ``pos_drop``) and then each stage in ``layers``, recording the pooled
    activation after the stem and after every stage. The last stage's pooled
    vector is passed through the model's ``norm``. The pooled vectors are
    concatenated along the channel axis, giving one feature row per sample.

    Args:
        dataloader: iterable yielding ``(x, y)`` batches; ``x`` is fed to
            ``patch_embed``.
        feature_model: backbone to use. Defaults to the notebook-global
            ``model`` (the original behaviour). Assumed to already be in eval
            mode (the model cell calls ``.eval()``); otherwise dropout such as
            ``pos_drop`` stays active.

    Returns:
        ``(features, labels)``: per-batch results concatenated along dim 0,
        both on the CPU.
    """
    net = model if feature_model is None else feature_model
    # Send inputs wherever the model lives instead of hard-coding .cuda().
    device = next(net.parameters()).device
    # Not every timm Swin build defines this attribute; treat missing as "none".
    abs_pos_embed = getattr(net, 'absolute_pos_embed', None)

    def _pool(tokens):
        # Average every axis between batch and channel: (B, L, C) token
        # sequences and (B, H, W, C) NHWC maps both reduce to (B, C).
        return tokens.mean(dim=tuple(range(1, tokens.ndim - 1)))

    features_all, y_all = [], []
    with torch.no_grad():
        for x, y in tqdm(dataloader):
            x = net.patch_embed(x.to(device))
            if abs_pos_embed is not None:
                x = x + abs_pos_embed
            x = net.pos_drop(x)

            pooled = [_pool(x)]
            for layer in net.layers:
                x = layer(x)
                pooled.append(_pool(x))
            # NOTE(review): norm is applied to the *pooled* last-stage vector.
            # timm's own forward path appears to normalise tokens before
            # pooling, which is not equivalent -- verify this is intended.
            pooled[-1] = net.norm(pooled[-1])

            # Move each batch to the CPU so device memory does not grow with
            # the dataset; downstream code only consumes CPU/numpy arrays.
            features_all.append(torch.cat(pooled, dim=1).cpu())
            y_all.append(y.cpu())

    return torch.cat(features_all, dim=0), torch.cat(y_all, dim=0)

In [4]:
# NOTE(review): the saved execution counts show this cell (In [4]) ran before
# the extract_features definition cell (In [6]) was last executed. The stored
# train features may therefore come from an older version of the function than
# the validation features (In [7]). Restart & Run All before trusting results.
features_all, y_all = extract_features(dataloader=train_dataloader)

100%|██████████| 34/34 [01:48<00:00,  3.20s/it]


In [7]:
# Validation features, extracted with the same function as the training set.
features_all_val, y_all_val = extract_features(dataloader=val_dataloader)

100%|██████████| 9/9 [00:25<00:00,  2.86s/it]


In [8]:
# Sanity checks before fitting: one label per feature row, and train/val
# feature vectors of equal width (a mismatch means they were extracted by
# different versions of extract_features).
assert features_all.shape[0] == y_all.shape[0], 'train features/labels misaligned'
assert features_all_val.shape[0] == y_all_val.shape[0], 'val features/labels misaligned'
assert features_all.shape[1] == features_all_val.shape[1], 'train/val feature widths differ'
features_all.shape, y_all.shape

(torch.Size([1073, 4416]), torch.Size([1073]))

In [10]:
# Linear SVM on the extracted features.
# probability=True makes SVC fit Platt scaling with an internal cross-validation
# that shuffles the data; random_state pins that shuffle so the probabilities
# (and the log loss below) are reproducible between runs.
clf = svm.SVC(kernel='linear', C=1, probability=True, random_state=0).fit(
    features_all.cpu().numpy(), y_all.cpu().numpy()
)

# Class probabilities on the validation set; columns follow clf.classes_.
y_all_val_pred = clf.predict_proba(features_all_val.cpu().numpy())

# Validation log loss. `labels` ties the probability columns to clf.classes_
# and keeps log_loss working if a fold's validation labels miss a class.
# `eps` is omitted: it was deprecated in scikit-learn 1.3 and removed in 1.5.
log_loss(y_all_val.cpu().numpy(), y_all_val_pred, labels=clf.classes_)

0.4148100175016549

In [11]:
# In-sample class probabilities: clf was fit on these same rows, so this score
# is an optimistic reference point, not a generalisation estimate.
y_all_pred = clf.predict_proba(features_all.cpu().numpy())

# Training log loss. `labels` ties the probability columns to clf.classes_;
# `eps` is omitted because it was deprecated in scikit-learn 1.3 and removed in 1.5.
log_loss(y_all.cpu().numpy(), y_all_pred, labels=clf.classes_)

0.2855574374923769