# Import 

In [2]:
!pip install einops
!git clone https://github.com/b-ptiste/dlmi.git

Collecting einops
  Downloading einops-0.7.0-py3-none-any.whl.metadata (13 kB)
Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.7.0


In [3]:
# Standard library imports
import os
import random
import copy
import time

# Related third-party imports
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import KFold
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from torchvision import transforms as T
from torchvision.transforms import v2
from tqdm import tqdm
import timm
import wandb
import uuid
from sklearn.metrics import balanced_accuracy_score

# log in different framework
path_root = "/kaggle/input/dlmi-challenge-b-and-s"
path_working = "/kaggle/working"
path_mae = "/kaggle/input/mae-pretrain"


# local library
from dlmi.src.model import ModelFactory
from dlmi.src.data import csv_processing, DataloaderFactory
from dlmi.src.utils import get_stratified_split
from dlmi.src.mae_pretraining import MAE_ViT, MAE_Encoder, MAE_Decoder, PatchShuffle
from dlmi.data.split import train_index as train_index_strat
from dlmi.data.split import val_index as val_index_strat

Cloning into 'dlmi'...
remote: Enumerating objects: 227, done.[K
remote: Counting objects: 100% (17/17), done.[K
remote: Compressing objects: 100% (16/16), done.[K
remote: Total 227 (delta 4), reused 3 (delta 0), pack-reused 210[K
Receiving objects: 100% (227/227), 45.05 KiB | 490.00 KiB/s, done.
Resolving deltas: 100% (84/84), done.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


# Create config.

This config contains all the hyparameter usefull for our experiments. There will be logged in wandb.

The weight for the MAE pretraining are avalaible and need to be download here : 

Drive with the weigth [here](https://drive.google.com/drive/u/0/folders/13yrd36hwnCahIzXtedJdakCQZdADHxLd)

In [4]:
cfg = {
    "who": "baptiste",  # or steven
    "no_wandb": False,
    "name_exp": "PatientModelCrossAttentionTab - vit_small_patch16_224 - lora",
    "lr": 5e-6,
    "batch_size": 1,
    "nb_epochs": 20,
    "timm": True,  # is the model from timm
    "timm_model": "vit_small_patch16_224.augreg_in21k",
    "dino": False,
    "dino_size": "vits",  # vits, vitb, vitl, vitg
    "adapter": "lora",  # bottleneck, adaptformer, lora, prompttuning
    "model_name": "PatientModelCrossAttentionTab",  # 'vit_small_patch16_224.augreg_in21k', #timm based model
    "pretrained": True,
    "pretrained_path": "",
    "nb_class": 2,
    "scheduler": None,  # could be empty or linear, expo ...
    "dataset_name": "DatasetPerPatient",
    "device_1": "cuda:0",
    "device_2": "cuda:1",  # for double device
    # data augmentation
    "filename": f"{path_working}/submission.csv",
    "sub_batch_size": 16,
    "latent_att": 512,
    "head_1": 8,  # 4
    "head_2": 2,
    "feature_dim": 384,  # DINOv2, VIT: 192 - 384
    "aggregation": "avg",  # sum, avg, max
    "beta_1": 0.5,
    "beta_2": 0.9,
    "weight_decay": 5e-2,
    "weight_class_0": 3.0,
    "weight_class_1": 1.0,
    "mask_ratio": 0.75,
    "image_size": 224,
    "patch_size": 16,
    "mae_pretrained": "small_testset_800it.pt",
    "with_tab": True,
    "mode_split": "strat",  # load, strat
    "degrees": (-5, 5),
    "translate": (0.1, 0.1),
    "scale": (0.9, 1.0),
    "fill": (255, 232, 201),
    "p": 0.1,
}

# Data importation

Csv import with pre-processing, reformatting and normalisation.

In [5]:
df_annotation_train = csv_processing(
    os.path.join(path_root, "trainset", "trainset_true.csv")
)
df_annotation_test = csv_processing(
    os.path.join(path_root, "testset", "testset_data.csv")
)

Create train, val, test dataset.

In [6]:
test_index = df_annotation_test.index.tolist()

if cfg["mode_split"] == "auto":
    map_mode_index = get_stratified_split(df_annotation_train, df_annotation_test)

    train_index = map_mode_index["train"]
    val_index = map_mode_index["val"]

elif cfg["mode_split"] == "load":
    # log wandb
    run = wandb.init()
    artifact = run.use_artifact(
        "ii_timm/DLMI/submission958f5028e70811ee9d6b0242ac130202:v0", type="csv"
    )
    artifact_dir = artifact.download(root=path_working)
    wandb.finish()

    train_index = pd.read_csv(f"{path_working}/train_index.csv")[
        "train"
    ].values.tolist()
    val_index = pd.read_csv(f"{path_working}/val_index.csv")["val"].values.tolist()

elif cfg["mode_split"] == "strat":
    train_index = train_index_strat
    val_index = val_index_strat

# Finetuning

In [7]:
# transform
transform_train = T.Compose(
    [
        v2.PILToTensor(),
        v2.RandomHorizontalFlip(p=cfg["p"]),
        v2.RandomVerticalFlip(p=cfg["p"]),
        v2.RandomAffine(
            degrees=cfg["degrees"],
            translate=cfg["translate"],
            scale=cfg["scale"],
            fill=cfg["fill"],
        ),
    ]
)

transform_val = T.Compose(
    [
        v2.PILToTensor(),
    ]
)

In [8]:
# load data
data_factory = DataloaderFactory()
model_factory = ModelFactory()
dataloader_train = data_factory(
    cfg,
    mode="train",
    split_indexes=train_index,
    path_root=path_root,
    shuffle=True,
    drop_last=True,
    transform=transform_train,
    oversampling={"0": 1, "1": 1},
)

dataloader_val = data_factory(
    cfg,
    mode="train",
    split_indexes=val_index,
    path_root=path_root,
    shuffle=False,
    drop_last=False,
    transform=transform_val,
    oversampling={"0": 1, "1": 1},
)


# load model
model = model_factory(cfg).to(cfg["device_1"])

# optimizer and scheduler
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=cfg["lr"],
    betas=(cfg["beta_1"], cfg["beta_2"]),
    weight_decay=cfg["weight_decay"],
)

scheduler = None  # torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg['nb_epochs'], eta_min=5e-6)

soft_max = torch.nn.Softmax(1)


cfg["nb_params_train"] = sum(p.numel() for p in model.parameters() if p.requires_grad)

cfg["nb_params_tot"] = sum(p.numel() for p in model.parameters())

print("=" * 50)
print(f'The model has {cfg["nb_params_tot"]} parameters')
print(f'The model has {cfg["nb_params_train"]} trainable parameters')
print(
    f'It represents {np.round(100 * cfg["nb_params_train"]/cfg["nb_params_tot"], 3)} % trainable parameters'
)
print("=" * 50)

The configuration is:
who : baptiste
no_wandb : False
name_exp : PatientModelCrossAttentionTab - vit_small_patch16_224 - lora
lr : 5e-06
batch_size : 1
nb_epochs : 20
timm : True
timm_model : vit_small_patch16_224.augreg_in21k
dino : False
dino_size : vits
adapter : lora
model_name : PatientModelCrossAttentionTab
pretrained : True
pretrained_path : 
nb_class : 2
scheduler : None
dataset_name : DatasetPerPatient
device_1 : cuda:0
device_2 : cuda:1
filename : /kaggle/working/submission.csv
sub_batch_size : 16
latent_att : 512
head_1 : 8
head_2 : 2
feature_dim : 384
aggregation : avg
beta_1 : 0.5
beta_2 : 0.9
weight_decay : 0.05
weight_class_0 : 3.0
weight_class_1 : 1.0
mask_ratio : 0.75
image_size : 224
patch_size : 16
mae_pretrained : small_testset_800it.pt
with_tab : True
mode_split : strat
degrees : (-5, 5)
translate : (0.1, 0.1)
scale : (0.9, 1.0)
fill : (255, 232, 201)
p : 0.1
Loading custom model PatientModelCrossAttentionTab


model.safetensors:   0%|          | 0.00/120M [00:00<?, ?B/s]

The training is from scatch
Use lora adapter
The model has 23505374 parameters
The model has 1839710 trainable parameters
It represents 7.827 % trainable parameters


In [9]:
# Load the entire model
if len(cfg["mae_pretrained"]) > 0:
    MAE_model = torch.load(os.path.join(path_mae, cfg["mae_pretrained"]))
    model.blocks = copy.deepcopy(MAE_model.encoder.model.blocks)

In [10]:
weight = torch.tensor([2.5, 1.0]).to("cuda:0")
loss_fn = torch.nn.CrossEntropyLoss(weight=weight)

In [11]:
#############################################
###              Training
#############################################

best_loss = 10000

if not cfg["no_wandb"]:
    run = wandb.init(
        project="DLMI",
        entity="ii_timm",
        name=cfg["name_exp"],
        config=cfg,
    )


print("Start Training ...")
for epoch in range(cfg["nb_epochs"]):
    model.train()
    print("=" * 50)
    print(" " * 15, f"Epoch {epoch}")
    print("=" * 50)

    train_cum_loss = 0
    start_time = time.time()

    #############################
    ###     VAL loop
    #############################
    train_pred = []
    train_label = []

    for x, annotation in tqdm(dataloader_train):
        optimizer.zero_grad()
        x = x.to(cfg["device_1"]).squeeze(0)

        if cfg["with_tab"]:
            # define tabular data
            lymph_count, age, bin_gender = (
                annotation["LYMPH_COUNT"],
                annotation["AGE"],
                annotation["BIN_GENDER"],
            )
            x_tab = torch.zeros((1, 4)).to(cfg["device_1"])
            x_tab[0, int(bin_gender)] = 1
            x_tab[0, 2] = torch.clamp(age + 1e-4 * np.random.rand(1)[0], 0, 1)
            x_tab[0, 3] = torch.clamp(lymph_count + 1e-4 * np.random.rand(1)[0], 0, 1)

            xout_sub_batch = model(x, x_tab, "train")
        else:
            # None tabular data
            xout_sub_batch = model(x, "train")

        # compute the loss and pred
        loss = loss_fn(
            xout_sub_batch.unsqueeze(0), annotation["LABEL"].to(cfg["device_1"])
        ) / (x.shape[0] / cfg["sub_batch_size"])
        pred = torch.argmax(soft_max(xout_sub_batch.unsqueeze(0)), dim=1)
        train_cum_loss += loss.item()

        # store the res.
        train_pred.extend(pred.detach().cpu().tolist())
        train_label.extend(annotation["LABEL"].detach().cpu().tolist())

        # backward
        loss.backward()
        optimizer.step()

        if scheduler is not None:
            scheduler.step()

    # compute agg. scores
    train_balance_acc = balanced_accuracy_score(train_pred, train_label)
    train_avg_loss = train_cum_loss / len(dataloader_train)
    print(
        f"train_balance_acc : {np.round(train_balance_acc, 6)} / train_avg_loss : {np.round(train_avg_loss, 6)}"
    )
    unique_train, count_train = np.unique(train_pred, return_counts=True)
    print(unique_train, count_train)

    # edge case
    if len(unique_train) == 1:
        if unique_train[0] == 1:
            count_train = count_train.tolist()
            count_train.insert(0, 0)
        if unique_train[0] == 0:
            count_train = count_train.tolist()
            count_train.append(0)

    #############################
    ###     VAL loop
    #############################
    val_pred = []
    val_label = []
    val_cum_loss = 0
    model.eval()
    for x, annotation in tqdm(dataloader_val):
        # forward
        x = x.to(cfg["device_1"]).squeeze(0)

        if cfg["with_tab"]:
            # define tabular data
            lymph_count, age, bin_gender = (
                annotation["LYMPH_COUNT"],
                annotation["AGE"],
                annotation["BIN_GENDER"],
            )
            x_tab = torch.zeros((1, 4)).to(cfg["device_1"])
            x_tab[0, int(bin_gender)] = 1
            x_tab[0, 2] = age
            x_tab[0, 3] = lymph_count

            xout_sub_batch = model(x, x_tab, "val")
        else:
            xout_sub_batch = model(x, "val")
        # compute loss
        loss = loss_fn(
            xout_sub_batch.unsqueeze(0), annotation["LABEL"].to(cfg["device_1"])
        ) / (x.shape[0] / cfg["sub_batch_size"])
        pred = torch.argmax(soft_max(xout_sub_batch.unsqueeze(0)), dim=1)
        val_cum_loss += loss.item()
        val_pred.extend(pred.detach().cpu().tolist())
        val_label.extend(annotation["LABEL"].detach().cpu().tolist())

    # compute agg. scores
    val_balance_acc = balanced_accuracy_score(val_pred, val_label)
    val_avg_loss = val_cum_loss / len(dataloader_val)
    print(
        f"val_balance_acc : {np.round(val_balance_acc, 6)} / val_avg_loss : {np.round(val_avg_loss, 6)}"
    )
    unique_val, count_val = np.unique(val_pred, return_counts=True)
    print(unique_val, count_val)

    # edge case
    if len(unique_val) == 1:
        if unique_val[0] == 1:
            count_val = count_val.tolist()
            count_val.insert(0, 0)
        if unique_val[0] == 0:
            count_val = count_val.tolist()
            count_val.append(0)
    print(val_avg_loss, best_loss)

    # Save best model + prints
    if val_avg_loss < best_loss:
        best_loss = val_avg_loss

        print("Improve avg loss :")
        save_path_finetune = os.path.join("./", "model" + str(epoch) + "_finetune.pt")
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict()
                if cfg["scheduler"] is not None
                else None,
            },
            save_path_finetune,
        )
        print("checkpoint saved to: {}".format(save_path_finetune))

    print(
        "time",
        (time.time() - start_time) / (len(dataloader_val) + len(dataloader_train)),
    )

    # Save in Wandb
    if not cfg["no_wandb"]:
        wandb.log(
            {
                "epoch": epoch,
                "balance_acc/train": train_balance_acc,
                "loss/train": train_avg_loss,
                "balance_acc/val": val_balance_acc,
                "loss/val": val_avg_loss,
                "time": (time.time() - start_time)
                / (len(dataloader_val) + len(dataloader_train)),
                "count_train_0": count_train[0],
                "count_train_1": count_train[1],
                "count_val_0": count_val[0],
                "count_val_1": count_val[1],
            }
        )

[34m[1mwandb[0m: Currently logged in as: [33mbaptcallard[0m ([33mii_timm[0m). Use [1m`wandb login --relogin`[0m to force relogin


Start Training ...
                Epoch 0


100%|██████████| 130/130 [02:14<00:00,  1.03s/it]


train_balance_acc : 0.692308 / train_avg_loss : 0.17662
[1] [130]


100%|██████████| 33/33 [00:22<00:00,  1.45it/s]


val_balance_acc : 0.69697 / val_avg_loss : 0.160592
[1] [33]
0.16059231633941332 10000
Improve avg loss :
checkpoint saved to: ./model0_finetune.pt
time 0.9652812466299607
                Epoch 1


100%|██████████| 130/130 [01:29<00:00,  1.46it/s]


train_balance_acc : 0.692308 / train_avg_loss : 0.175508
[1] [130]


100%|██████████| 33/33 [00:11<00:00,  2.81it/s]


val_balance_acc : 0.69697 / val_avg_loss : 0.160385
[1] [33]
0.16038476472551172 0.16059231633941332
Improve avg loss :
checkpoint saved to: ./model1_finetune.pt
time 0.6218662949427505
                Epoch 2


100%|██████████| 130/130 [01:28<00:00,  1.48it/s]


train_balance_acc : 0.692308 / train_avg_loss : 0.176352
[1] [130]


100%|██████████| 33/33 [00:11<00:00,  2.94it/s]


val_balance_acc : 0.69697 / val_avg_loss : 0.159445
[1] [33]
0.159444641999223 0.16038476472551172
Improve avg loss :
checkpoint saved to: ./model2_finetune.pt
time 0.6113074469420076
                Epoch 3


100%|██████████| 130/130 [01:27<00:00,  1.48it/s]


train_balance_acc : 0.692308 / train_avg_loss : 0.175708
[1] [130]


100%|██████████| 33/33 [00:11<00:00,  2.90it/s]


val_balance_acc : 0.69697 / val_avg_loss : 0.154394
[1] [33]
0.1543941211068269 0.159444641999223
Improve avg loss :
checkpoint saved to: ./model3_finetune.pt
time 0.6093163753579731
                Epoch 4


100%|██████████| 130/130 [01:27<00:00,  1.49it/s]


train_balance_acc : 0.344961 / train_avg_loss : 0.174251
[0 1] [  1 129]


100%|██████████| 33/33 [00:11<00:00,  2.86it/s]


val_balance_acc : 0.69697 / val_avg_loss : 0.153844
[1] [33]
0.15384402667934244 0.1543941211068269
Improve avg loss :
checkpoint saved to: ./model4_finetune.pt
time 0.607740064340135
                Epoch 5


100%|██████████| 130/130 [01:28<00:00,  1.47it/s]


train_balance_acc : 0.567734 / train_avg_loss : 0.168104
[0 1] [ 14 116]


100%|██████████| 33/33 [00:11<00:00,  2.85it/s]


val_balance_acc : 0.870968 / val_avg_loss : 0.137757
[0 1] [ 2 31]
0.13775692078651805 0.15384402667934244
Improve avg loss :
checkpoint saved to: ./model5_finetune.pt
time 0.6155322899847674
                Epoch 6


100%|██████████| 130/130 [01:29<00:00,  1.46it/s]


train_balance_acc : 0.590475 / train_avg_loss : 0.163825
[0 1] [ 29 101]


100%|██████████| 33/33 [00:11<00:00,  2.77it/s]


val_balance_acc : 0.69697 / val_avg_loss : 0.137474
[1] [33]
0.13747364491450065 0.13775692078651805
Improve avg loss :
checkpoint saved to: ./model6_finetune.pt
time 0.6223686633665869
                Epoch 7


100%|██████████| 130/130 [01:27<00:00,  1.48it/s]


train_balance_acc : 0.679242 / train_avg_loss : 0.158613
[0 1] [ 29 101]


100%|██████████| 33/33 [00:11<00:00,  2.86it/s]


val_balance_acc : 0.69697 / val_avg_loss : 0.142136
[1] [33]
0.1421360850447055 0.13747364491450065
time 0.6098025093780705
                Epoch 8


100%|██████████| 130/130 [01:29<00:00,  1.46it/s]


train_balance_acc : 0.69 / train_avg_loss : 0.159351
[0 1] [ 30 100]


100%|██████████| 33/33 [00:11<00:00,  2.91it/s]


val_balance_acc : 0.8775 / val_avg_loss : 0.116883
[0 1] [ 8 25]
0.11688325796840769 0.13747364491450065
Improve avg loss :
checkpoint saved to: ./model8_finetune.pt
time 0.61893706526493
                Epoch 9


100%|██████████| 130/130 [01:30<00:00,  1.44it/s]


train_balance_acc : 0.710459 / train_avg_loss : 0.155474
[0 1] [32 98]


100%|██████████| 33/33 [00:11<00:00,  2.77it/s]


val_balance_acc : 0.69697 / val_avg_loss : 0.143872
[1] [33]
0.1438716763461178 0.11688325796840769
time 0.6260636086844228
                Epoch 10


100%|██████████| 130/130 [01:29<00:00,  1.45it/s]


train_balance_acc : 0.719413 / train_avg_loss : 0.151164
[0 1] [37 93]


100%|██████████| 33/33 [00:11<00:00,  2.83it/s]


val_balance_acc : 0.896552 / val_avg_loss : 0.112214
[0 1] [ 4 29]
0.11221387954146574 0.11688325796840769
Improve avg loss :
checkpoint saved to: ./model10_finetune.pt
time 0.6237872012553771
                Epoch 11


100%|██████████| 130/130 [01:32<00:00,  1.41it/s]


train_balance_acc : 0.767435 / train_avg_loss : 0.140053
[0 1] [36 94]


100%|██████████| 33/33 [00:12<00:00,  2.72it/s]


val_balance_acc : 0.886364 / val_avg_loss : 0.103265
[0 1] [11 22]
0.10326516941528428 0.11221387954146574
Improve avg loss :
checkpoint saved to: ./model11_finetune.pt
time 0.6434318770660213
                Epoch 12


100%|██████████| 130/130 [01:34<00:00,  1.38it/s]


train_balance_acc : 0.700523 / train_avg_loss : 0.13576
[0 1] [37 93]


100%|██████████| 33/33 [00:14<00:00,  2.26it/s]


val_balance_acc : 0.69697 / val_avg_loss : 0.142303
[1] [33]
0.14230306974301735 0.10326516941528428
time 0.6683420681514622
                Epoch 13


100%|██████████| 130/130 [01:35<00:00,  1.37it/s]


train_balance_acc : 0.79304 / train_avg_loss : 0.133264
[0 1] [39 91]


100%|██████████| 33/33 [00:11<00:00,  2.83it/s]


val_balance_acc : 0.8775 / val_avg_loss : 0.098711
[0 1] [ 8 25]
0.09871128980409015 0.10326516941528428
Improve avg loss :
checkpoint saved to: ./model13_finetune.pt
time 0.6570844065192287
                Epoch 14


100%|██████████| 130/130 [01:32<00:00,  1.40it/s]


train_balance_acc : 0.765152 / train_avg_loss : 0.136737
[0 1] [42 88]


100%|██████████| 33/33 [00:13<00:00,  2.52it/s]


val_balance_acc : 0.8775 / val_avg_loss : 0.093213
[0 1] [ 8 25]
0.09321277832725283 0.09871128980409015
Improve avg loss :
checkpoint saved to: ./model14_finetune.pt
time 0.6529751555319944
                Epoch 15


100%|██████████| 130/130 [01:36<00:00,  1.34it/s]


train_balance_acc : 0.791366 / train_avg_loss : 0.122195
[0 1] [43 87]


100%|██████████| 33/33 [00:11<00:00,  2.81it/s]


val_balance_acc : 0.821154 / val_avg_loss : 0.099593
[0 1] [13 20]
0.09959263756701892 0.09321277832725283
time 0.6668081254315522
                Epoch 16


100%|██████████| 130/130 [01:32<00:00,  1.41it/s]


train_balance_acc : 0.786643 / train_avg_loss : 0.123956
[0 1] [36 94]


100%|██████████| 33/33 [00:11<00:00,  2.82it/s]


val_balance_acc : 0.886364 / val_avg_loss : 0.087496
[0 1] [11 22]
0.08749590593982827 0.09321277832725283
Improve avg loss :
checkpoint saved to: ./model16_finetune.pt
time 0.6405976740129155
                Epoch 17


100%|██████████| 130/130 [01:32<00:00,  1.41it/s]


train_balance_acc : 0.803204 / train_avg_loss : 0.120685
[0 1] [38 92]


100%|██████████| 33/33 [00:11<00:00,  2.78it/s]


val_balance_acc : 0.821154 / val_avg_loss : 0.094762
[0 1] [13 20]
0.09476168512959372 0.08749590593982827
time 0.6383592906905098
                Epoch 18


100%|██████████| 130/130 [01:30<00:00,  1.44it/s]


train_balance_acc : 0.789522 / train_avg_loss : 0.11643
[0 1] [34 96]


100%|██████████| 33/33 [00:11<00:00,  2.80it/s]


val_balance_acc : 0.821154 / val_avg_loss : 0.095105
[0 1] [13 20]
0.09510490696199915 0.08749590593982827
time 0.6275821858388515
                Epoch 19


100%|██████████| 130/130 [01:33<00:00,  1.38it/s]


train_balance_acc : 0.808741 / train_avg_loss : 0.118781
[0 1] [43 87]


100%|██████████| 33/33 [00:11<00:00,  2.77it/s]


val_balance_acc : 0.85119 / val_avg_loss : 0.087469
[0 1] [12 21]
0.08746904189783064 0.08749590593982827
Improve avg loss :
checkpoint saved to: ./model19_finetune.pt
time 0.6508636723266789


In [12]:
if not cfg["no_wandb"]:
    model_artifact = wandb.Artifact(
        "model" + str(uuid.uuid1()).replace("-", ""), type="model"
    )
    model_artifact.add_file(save_path_finetune)
    wandb.log_artifact(model_artifact)

    description_artifact = wandb.Artifact(
        "description_model" + str(uuid.uuid1()).replace("-", ""), type="python"
    )

    !cp -r $path_working/dlmi/src/* $path_working/
    description_artifact.add_file(f"{path_working}/model.py")
    description_artifact.add_file(f"{path_working}/utils.py")
    description_artifact.add_file(f"{path_working}/data.py")
    wandb.log_artifact(description_artifact)

# Prediction

In [13]:
dataloader_test = data_factory(
    cfg,
    mode="test",
    split_indexes=test_index,
    path_root=path_root,
    shuffle=False,
    drop_last=False,
    transform=transform_val,
)

In [14]:
test_pred = []
test_ID = []
map_results = {
    "Id": [],
    "Predicted": [],
}

map_results_logit = {
    "Id": [],
    "logit_0": [],
    "logit_1": [],
}

print("Load model", save_path_finetune)
model.load_state_dict(torch.load(save_path_finetune)["model_state_dict"])
model.eval()
for x, annotation in tqdm(dataloader_test):
    # forward
    x = x.to(cfg["device_1"])
    with torch.no_grad():
        x = x.to(cfg["device_1"]).squeeze(0)

        if cfg["with_tab"]:
            # define tabular data
            lymph_count, age, bin_gender = (
                annotation["LYMPH_COUNT"],
                annotation["AGE"],
                annotation["BIN_GENDER"],
            )
            x_tab = torch.zeros((1, 4)).to(cfg["device_1"])
            x_tab[0, int(bin_gender)] = 1
            x_tab[0, 2] = age
            x_tab[0, 3] = lymph_count

            x = model(x, x_tab, "val")
        else:
            x = model(x, "val")

        logit = soft_max(x.unsqueeze(0))
        pred = torch.argmax(logit, dim=1)

        map_results["Predicted"].extend(pred.detach().cpu().tolist())
        map_results["Id"].extend(annotation["ID"])

        map_results_logit["logit_0"].append(logit[0][0].item())
        map_results_logit["logit_1"].append(logit[0][1].item())
        map_results_logit["Id"].extend(annotation["ID"])

Load model ./model19_finetune.pt


100%|██████████| 42/42 [00:26<00:00,  1.56it/s]


Save in Wandb !

In [15]:
df_results = pd.DataFrame(map_results)
df_results.to_csv("submission.csv", index=False)

df_results = pd.DataFrame(map_results_logit)
df_results.to_csv("logit.csv", index=False)

if not cfg["no_wandb"]:
    # log index
    df_train_index = pd.DataFrame({"train": train_index})
    df_train_index.to_csv("train_index.csv", index=False)
    df_val_index = pd.DataFrame({"val": val_index})
    df_val_index.to_csv("val_index.csv", index=False)

    csv_artifact = wandb.Artifact(
        "submission" + str(uuid.uuid1()).replace("-", ""), type="csv"
    )
    csv_artifact.add_file("submission.csv")
    csv_artifact.add_file("logit.csv")
    csv_artifact.add_file("train_index.csv")
    csv_artifact.add_file("val_index.csv")
    wandb.log_artifact(csv_artifact)

    wandb.finish()

VBox(children=(Label(value='185.151 MB of 185.151 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
balance_acc/train,▆▆▆▆▁▄▅▆▆▇▇▇▆█▇█████
balance_acc/val,▁▁▁▁▁▇▁▁▇▁██▁▇▇▅█▅▅▆
count_train_0,▁▁▁▁▁▃▆▆▆▆▇▇▇▇██▇▇▇█
count_train_1,█████▆▃▃▃▃▂▂▂▂▁▁▂▂▂▁
count_val_0,▁▁▁▁▁▂▁▁▅▁▃▇▁▅▅█▇██▇
count_val_1,█████▇██▄█▆▂█▄▄▁▂▁▁▂
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
loss/train,█████▇▇▆▆▆▅▄▃▃▃▂▂▁▁▁
loss/val,███▇▇▆▆▆▄▆▃▃▆▂▂▂▁▂▂▁
time,█▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▁▂

0,1
balance_acc/train,0.80874
balance_acc/val,0.85119
count_train_0,43.0
count_train_1,87.0
count_val_0,12.0
count_val_1,21.0
epoch,19.0
loss/train,0.11878
loss/val,0.08747
time,0.65087
