In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

import pandas as pd
import timm
import torch
from tqdm import tqdm

from model.patch_embed_with_backbone import PatchBackbone
from data.visiomel_datamodule import VisiomelDatamodule

In [3]:
# --- Patch-embedding backbone ---
# timm architecture name, and the checkpoint its weights are loaded from
# (set the checkpoint path to None to skip loading).
patch_embed_backbone_name = 'swinv2_base_window12to24_192to384_22kft1k'
patch_embed_backbone_ckpt_path = '/workspace/visiomel-2023/weights/val_ssup_patches_aug/checkpoints/last.ckpt'

# --- Sizes ---
# Used as the image size of both the timm backbone and the datamodule, and as
# the patch size handed to PatchBackbone.
patch_size = 1536
# Forwarded to PatchBackbone as `patch_batch_size`.
patch_batch_size = 5
# Forwarded to the datamodule as `batch_size`.
batch_size = 1

# --- Augmentation and output ---
# Forwarded to the datamodule as `train_transform_n_repeats`.
train_transform_n_repeats = 5
# Prefix the saving cells prepend to the pickle file names.
save_path = '/workspace/visiomel-2023/weights/val_ssup_patches_aug/embeddings/'

Data

In [4]:
# Keyword arguments for the datamodule, gathered first so the configuration
# can be read (and inspected) separately from the construction itself.
datamodule_kwargs = dict(
    task='simmim_randaug',
    data_dir_train='/workspace/data/images_page_4_shrink/',
    k=5,
    fold_index=0,
    data_dir_test=None,
    img_size=patch_size,
    shrink_preview_scale=None,
    batch_size=batch_size,
    split_seed=0,
    num_workers=4,
    num_workers_saturated=4,
    pin_memory=False,
    prefetch_factor=None,
    persistent_workers=True,
    sampler=None,
    enable_caching=False,
    data_shrinked=True,
    train_resize_type='none',
    train_transform_n_repeats=train_transform_n_repeats,
)
datamodule = VisiomelDatamodule(**datamodule_kwargs)
datamodule.setup()

train_dataloader = datamodule.train_dataloader()
# The result of val_dataloader() is unpacked into two values; only the first
# one is used in this notebook, the second is discarded.
val_dataloader, _ = datamodule.val_dataloader()

Model

In [5]:
# Instantiate the backbone architecture only: pretrained=False skips timm's
# published weights, and num_classes=0 asks timm for a model without a
# classifier head. The weights come from the checkpoint below.
backbone = timm.create_model(
    patch_embed_backbone_name,
    img_size=patch_size,
    pretrained=False,
    num_classes=0,
)
if patch_embed_backbone_ckpt_path is not None:
    # If backbone is fine-tuned then it is done via SwinTransformerV2SimMIM
    # module, so the encoder weights are stored under the 'model.encoder.'
    # prefix. The state_dict is loaded straight into the timm backbone, whose
    # keys carry no such prefix, so the prefix must be removed entirely.
    # (Rewriting it to 'model.' leaves every key unmatched, and strict=False
    # hides that, so the backbone would silently keep its random init.)
    encoder_prefix = 'model.encoder.'
    # map_location='cpu' keeps loading independent of the device the
    # checkpoint was saved from; the wrapped model is moved to GPU below.
    ckpt_state_dict = torch.load(
        patch_embed_backbone_ckpt_path, map_location='cpu'
    )['state_dict']
    # Keys without the prefix are passed through unchanged.
    state_dict = {
        (k[len(encoder_prefix):] if k.startswith(encoder_prefix) else k): v
        for k, v in ckpt_state_dict.items()
    }
    # strict=False tolerates checkpoint entries that do not belong to the
    # backbone. Shape mismatches on matching keys still raise.
    load_result = backbone.load_state_dict(state_dict, strict=False)

    # Fail loudly if nothing was loaded, rather than spending hours
    # extracting embeddings from a randomly initialised backbone.
    n_backbone_keys = len(backbone.state_dict())
    n_missing = len(load_result.missing_keys)
    if n_missing >= n_backbone_keys:
        raise RuntimeError(
            f'No backbone weights were loaded from {patch_embed_backbone_ckpt_path}: '
            'none of the checkpoint keys match the backbone state_dict.'
        )
    print(
        f'Loaded {n_backbone_keys - n_missing}/{n_backbone_keys} backbone tensors '
        f'({len(load_result.unexpected_keys)} checkpoint keys ignored).'
    )

# Wrap the backbone for patch-wise embedding; eval() disables training-mode
# behaviour, and the whole module is moved to the GPU.
patch_embed = PatchBackbone(
    backbone=backbone,
    patch_size=patch_size,
    embed_dim=backbone.num_features,
    patch_batch_size=patch_batch_size,
    patch_embed_caching=False,
).cuda().eval()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [6]:
# Embed the training samples one at a time and keep the results on the CPU.
# NOTE: `features`, `labels` and `paths` are the names the saving cell reads,
# and the validation loop reuses them, so save before running that loop.
features, labels, paths = [], [], []
with torch.no_grad():
    for images, masks, targets, sample_paths in tqdm(train_dataloader):  # batches
        # Each batch unpacks into four parallel sequences; the mask is not
        # used for embedding extraction.
        for image, _mask, target, sample_path in zip(images, masks, targets, sample_paths):  # samples
            # Re-add the batch dimension so the model gets a single-sample batch.
            embedding = patch_embed(image.unsqueeze(0).cuda())
            features.append(embedding.detach().cpu())
            labels.append(target.detach().cpu())
            paths.append(sample_path)

100%|██████████| 1073/1073 [5:54:06<00:00, 19.80s/it] 


In [8]:
# Persist the training embeddings. `paths`, `labels` and `features` are the
# lists filled by the most recently executed extraction loop, so run this cell
# right after the training loop, before the validation loop overwrites them.
df_train = pd.DataFrame({
    'path': paths,
    'label': labels,
    'features': features,
})
# Create the output directory first: a missing directory would otherwise make
# to_pickle fail and discard the result of a multi-hour extraction run.
os.makedirs(save_path, exist_ok=True)
# os.path.join works whether or not save_path ends with a separator.
df_train.to_pickle(os.path.join(save_path, 'train.pkl'))

In [6]:
# Embed the validation samples one at a time and keep the results on the CPU.
# NOTE: this reuses (and therefore resets) `features`, `labels` and `paths`,
# the same names the training loop fills.
features, labels, paths = [], [], []
with torch.no_grad():
    for images, masks, targets, sample_paths in tqdm(val_dataloader):  # batches
        # Each batch unpacks into four parallel sequences; the mask is not
        # used for embedding extraction.
        for image, _mask, target, sample_path in zip(images, masks, targets, sample_paths):  # samples
            # Re-add the batch dimension so the model gets a single-sample batch.
            embedding = patch_embed(image.unsqueeze(0).cuda())
            features.append(embedding.detach().cpu())
            labels.append(target.detach().cpu())
            paths.append(sample_path)

100%|██████████| 269/269 [1:29:48<00:00, 20.03s/it]


In [7]:
# Persist the validation embeddings. `paths`, `labels` and `features` are the
# lists filled by the most recently executed extraction loop, so run this cell
# right after the validation loop.
df_val = pd.DataFrame({
    'path': paths,
    'label': labels,
    'features': features,
})
# Create the output directory first: a missing directory would otherwise make
# to_pickle fail and discard the result of a long extraction run.
os.makedirs(save_path, exist_ok=True)
# os.path.join works whether or not save_path ends with a separator.
df_val.to_pickle(os.path.join(save_path, 'val_aug.pkl'))