In [1]:
%load_ext autoreload
%autoreload 2

In [16]:
import timm
import torch
from tqdm import tqdm
from sklearn import svm
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split
from PIL import Image
Image.MAX_IMAGE_PIXELS = None

from src.model.swin_transformer_v2_classifier import SwinTransformerV2Classifier
from src.data.visiomel_datamodule import VisiomelTrainDatamodule, VisiomelTrainDatamoduleSimMIM

In [10]:
def extract_features_single(model, x):
    x = model.patch_embed(x)
    if model.absolute_pos_embed is not None:
        x = x + model.absolute_pos_embed
    x = model.pos_drop(x)

    features = [x.mean(dim=1)]
    for layer in model.layers:
        x = layer(x)
        features.append(x.mean(dim=1))
    features[-1] = model.norm(features[-1])
    features = torch.cat(features, dim=1)

    return features

def extract_features(model, dataloader):
    features_all, y_all = [], []
    for batch in tqdm(dataloader):
        if len(batch) == 2:
            x, y = batch
        elif len(batch) == 3:
            x, mask, y = batch
        with torch.no_grad():
            x, y = x.cuda(), y.cuda()
            features = extract_features_single(model, x)
            features_all.append(features)
            y_all.append(y)

    features_all = torch.cat(features_all, dim=0)
    y_all = torch.cat(y_all, dim=0)

    return features_all, y_all

# Pretrained

In [2]:
model = timm.create_model(
    'swinv2_large_window12to24_192to384_22kft1k', 
    pretrained=True, 
    num_classes=0
).cuda().eval()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [3]:
P = 384

### No patches

In [None]:
datamodule = VisiomelTrainDatamodule(
    data_dir_train = '/workspace/data/images_page_7/',	
    k = 5,
    fold_index = 0,
    data_dir_test = None,
    img_size = 384,
    shrink_preview_scale = None,
    batch_size = 32,
    split_seed = 0,
    num_workers = 0,
    pin_memory = False,
    prefetch_factor = None,
    persistent_workers = False,
    sampler = None,
    data_shrinked=False,
    num_workers_saturated=10,
    enable_caching=False,
    train_resize_type='resize',
)
datamodule.setup()
train_dataloader = datamodule.train_dataloader()
val_dataloader, _ = datamodule.val_dataloader()

In [4]:
features_all, y_all = extract_features(model, train_dataloader)

100%|██████████| 34/34 [01:48<00:00,  3.20s/it]


In [7]:
features_all_val, y_all_val = extract_features(model, val_dataloader)

100%|██████████| 9/9 [00:25<00:00,  2.86s/it]


In [8]:
features_all.shape, y_all.shape

(torch.Size([1073, 4416]), torch.Size([1073]))

In [10]:
# SVM classifier on features_all and y_all
clf = svm.SVC(kernel='linear', C=1, probability=True).fit(features_all.cpu().numpy(), y_all.cpu().numpy())

# Predict on validation set
y_all_val_pred = clf.predict_proba(features_all_val.cpu().numpy())

# log_loss
log_loss(y_all_val.cpu().numpy(), y_all_val_pred, eps=1e-16)

0.4148100175016549

In [11]:
# Predict on train set
y_all_pred = clf.predict_proba(features_all.cpu().numpy())

# log_loss
log_loss(y_all.cpu().numpy(), y_all_pred, eps=1e-16)

0.2855574374923769

### Patches

In [5]:
def extract_features_patches(model, dataloader):
    features_all, y_all = [], []
    for x, y in tqdm(dataloader):
        with torch.no_grad():
            x, y = x.cuda(), y.cuda()

            B, C, H, W = x.shape
            P = 384

            # Pad image to be divisible by 384
            x = torch.nn.functional.pad(x, (0, P - x.shape[3] % P, 0, P - x.shape[2] % P), mode='reflect')

            # Extract patches
            x = x \
                .unfold(2, P, P) \
                .unfold(3, P, P)

            # Concat to batch dimension
            # (B, C, H_patches, W_patches, patch_size, patch_size) -> 
            # (B * H_patches * W_patches, C, patch_size, patch_size)
            H_patches, W_patches = x.shape[2:4]
            x = x.reshape(B * H_patches * W_patches, C, P, P)

            # Extract features
            features = extract_features_single(model, x)
            print(features.shape)

            # Reshape back as expected from PatchEmbed
            # (B * H_patches * W_patches, *backbone_out_shape) ->
            # (B, H * W, *backbone_out_shape)
            x = x.reshape(B, H_patches * W_patches, *x.shape[1:])
            
            features_all.append(features)
            y_all.append(y)

    features_all = torch.cat(features_all, dim=0)
    y_all = torch.cat(y_all, dim=0)

    return features_all, y_all

In [6]:
datamodule = VisiomelTrainDatamodule(
    data_dir_train = '/workspace/data/images_page_7/',	
    k = 5,
    fold_index = 0,
    data_dir_test = None,
    img_size = 384 * 2,
    shrink_preview_scale = None,
    batch_size = 4,
    split_seed = 0,
    num_workers = 0,
    pin_memory = False,
    prefetch_factor = None,
    persistent_workers = False,
    sampler = None,
    data_shrinked=False,
    num_workers_saturated=10,
    enable_caching=False,
    train_resize_type='resize',
)
datamodule.setup()
train_dataloader = datamodule.train_dataloader()
val_dataloader, _ = datamodule.val_dataloader()

In [7]:
features_all, y_all = extract_features_patches(model, train_dataloader)

100%|██████████| 269/269 [14:43<00:00,  3.28s/it]


In [8]:
features_all_val, y_all_val = extract_features_patches(model, val_dataloader)

100%|██████████| 68/68 [04:13<00:00,  3.73s/it]


In [9]:
features_all.shape, y_all.shape

(torch.Size([9657, 4416]), torch.Size([1073]))

In [13]:
features_all.reshape(y_all.shape[0], features_all.shape[1], -1).flatten(1).shape

torch.Size([1073, 39744])

In [15]:
features_all = features_all.reshape(y_all.shape[0], features_all.shape[1], -1).flatten(1)
features_all_val = features_all_val.reshape(y_all_val.shape[0], features_all_val.shape[1], -1).flatten(1)

In [18]:
features_all.shape

torch.Size([1073, 39744])

In [16]:
# SVM classifier on features_all and y_all
clf = svm.SVC(kernel='linear', C=1, probability=True).fit(features_all.cpu().numpy(), y_all.cpu().numpy())

# Predict on validation set
y_all_val_pred = clf.predict_proba(features_all_val.cpu().numpy())

# log_loss
log_loss(y_all_val.cpu().numpy(), y_all_val_pred, eps=1e-16)

0.40077315802046426

In [17]:
# Predict on train set
y_all_pred = clf.predict_proba(features_all.cpu().numpy())

# log_loss
log_loss(y_all.cpu().numpy(), y_all_pred, eps=1e-16)

0.24109060020013015

# SSUP

In [4]:
state_dict = {
    k \
        .replace('model.encoder.', 'model.'): v 
    for k, v in 
    torch.load('/workspace/visiomel-2023/visiomel/eahzd753/checkpoints/last.ckpt')['state_dict'].items()
}
model = SwinTransformerV2Classifier(
    model_name='swinv2_large_window12to24_192to384_22kft1k', 
    num_classes=0, 
    img_size=1536, 
    patch_size=4,
)
model.load_state_dict(state_dict, strict=False)
model = model.cuda()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
_IncompatibleKeys(missing_keys=['layers.0.blocks.1.attn_mask', 'layers.1.blocks.1.attn_mask', 'layers.2.blocks.1.attn_mask', 'layers.2.blocks.3.attn_mask', 'layers.2.blocks.5.attn_mask', 'layers.2.blocks.7.attn_mask', 'layers.2.blocks.9.attn_mask', 'layers.2.blocks.11.attn_mask', 'layers.2.blocks.13.attn_mask', 'layers.2.blocks.15.attn_mask', 'layers.2.blocks.17.attn_mask', 'layers.3.blocks.1.attn_mask'], unexpected_keys=[])


In [5]:
datamodule = VisiomelTrainDatamoduleSimMIM(
    data_dir_train = '/workspace/data/images_page_4_shrink/',	
    data_dir_test = None,
    img_size = 1536,
    shrink_preview_scale = None,
    batch_size = 2,
    split_seed = 0,
    num_workers = 4,
    pin_memory = False,
    prefetch_factor = None,
    persistent_workers = False,
    sampler = None,
    data_shrinked=True,
    num_workers_saturated=4,
    enable_caching=False,
    train_resize_type='resize',
)
datamodule.setup()
train_dataloader = datamodule.train_dataloader()

In [8]:
def extract_features_single_ssup(model, x, mask):
    x = model.patch_embed(x)
    if model.absolute_pos_embed is not None:
        x = x + model.absolute_pos_embed
    x = model.pos_drop(x)

    features = [x.mean(dim=1)]
    for layer in model.layers:
        x = layer(x)
        features.append(x.mean(dim=1))
    features[-1] = model.norm(features[-1])
    features = torch.cat(features, dim=1)

    return features

def extract_features_ssup(model, dataloader):
    features_all, y_all = [], []
    for x, mask, y in tqdm(dataloader):
        with torch.no_grad():
            x, y = x.cuda(), y.cuda()
            features = extract_features_single_ssup(model, x, mask)
            features_all.append(features)
            y_all.append(y)

    features_all = torch.cat(features_all, dim=0)
    y_all = torch.cat(y_all, dim=0)

    return features_all, y_all

In [13]:
features_all, y_all = extract_features(model.model, train_dataloader)

100%|██████████| 671/671 [10:28<00:00,  1.07it/s]


In [14]:
features_all.shape, y_all.shape

(torch.Size([1342, 4416]), torch.Size([1342]))

In [17]:
features_all, features_all_val, y_all, y_all_val = train_test_split(features_all, y_all, test_size=0.2, random_state=0)

In [18]:
# SVM classifier on features_all and y_all
clf = svm.SVC(kernel='linear', C=1, probability=True).fit(features_all.cpu().numpy(), y_all.cpu().numpy())

# Predict on validation set
y_all_val_pred = clf.predict_proba(features_all_val.cpu().numpy())

# log_loss
log_loss(y_all_val.cpu().numpy(), y_all_val_pred, eps=1e-16)

0.35908364195215864

In [19]:
# Predict on train set
y_all_pred = clf.predict_proba(features_all.cpu().numpy())

# log_loss
log_loss(y_all.cpu().numpy(), y_all_pred, eps=1e-16)

0.28846559331031174