In [1]:
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric --no-index --find-links=/kaggle/input/pytorch-geometric-packages

Looking in links: /kaggle/input/pytorch-geometric-packages
Processing /kaggle/input/pytorch-geometric-packages/torch_scatter-2.1.2-cp310-cp310-linux_x86_64.whl
Processing /kaggle/input/pytorch-geometric-packages/torch_sparse-0.6.18-cp310-cp310-linux_x86_64.whl
Processing /kaggle/input/pytorch-geometric-packages/torch_cluster-1.6.3-cp310-cp310-linux_x86_64.whl
Processing /kaggle/input/pytorch-geometric-packages/torch_spline_conv-1.2.2-cp310-cp310-linux_x86_64.whl
Processing /kaggle/input/pytorch-geometric-packages/torch_geometric-2.5.3-py3-none-any.whl
Installing collected packages: torch-spline-conv, torch-scatter, torch-sparse, torch-cluster, torch-geometric
Successfully installed torch-cluster-1.6.3 torch-geometric-2.5.3 torch-scatter-2.1.2 torch-sparse-0.6.18 torch-spline-conv-1.2.2


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATConv, GCNConv, LayerNorm, GraphNorm, InstanceNorm
from torch_geometric.nn.models import GAT, GraphSAGE, PMLP
from torch_geometric.nn.pool import max_pool
from torch_geometric.nn.norm import BatchNorm
from torch_geometric.utils import to_dense_adj
from torch_geometric.utils import dropout_edge
from scipy.spatial import distance_matrix
import networkx as nx
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
from sklearn.metrics import roc_curve, auc, roc_auc_score
from sklearn.model_selection import GroupKFold, StratifiedGroupKFold
from torchvision.models.feature_extraction import create_feature_extractor
from torch_geometric.nn.pool import global_max_pool

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
from tqdm import tqdm
from sklearn.preprocessing import StandardScaler, OrdinalEncoder
import random

## Config

In [3]:
batch_size = 10
num_epochs = 50
kfold = 5
alpha = 0.25
gamma = 2
beta = 0.1

## Modeling

### Prepare data

In [4]:
df = pd.read_csv('/kaggle/input/isic-2024-challenge/train-metadata.csv')

In [5]:
num_positive = df[df.target == 1].shape[0]
num_negative = df[df.target == 0].shape[0]
total = num_positive + num_negative
class_weight = torch.tensor([total/num_negative, total/num_positive]).cuda()

In [6]:
#df_eff = pd.read_csv("/kaggle/input/skin-cancer-image-model-output/output_eff.csv")
#df_eff = pd.concat([df.copy(), df_eff], axis=1)
#df_eff = df_eff[['patient_id', "target_effnetv1b0"]]
#df_eva = pd.read_csv("/kaggle/input/skin-cancer-image-model-output/output_eva.csv")
#df_eva = pd.concat([df.copy(), df_eva], axis=1)
#df_eva = df_eva[['patient_id', "target_eva02"]]

In [7]:
data = []
label_img = []
label_node = []
data_size = []
for id_ in tqdm(df.patient_id.unique()):
    data_i = df[df.patient_id == id_]
    label_node.append(np.array(data_i.target))
    label_img.append(data_i.target.max())
    data.append(data_i)
    data_size.append(data_i.shape[0])

100%|██████████| 1042/1042 [01:11<00:00, 14.54it/s]


In [8]:
data_list = []
for i, data_ in tqdm(enumerate(data)):
    id_ = data_.patient_id.iloc[0]
    node_label = torch.tensor(np.array(data_.target))
    #node_label = torch.tensor(np.array(data_.target.max()))
    #node_feature_eff = np.array(df_eff[df.patient_id == id_]['target_effnetv1b0'])[:, None].astype('float32')
    #node_feature_eva = np.array(df_eva[df.patient_id == id_]['target_eva02'])[:, None].astype('float32')
    #node_feature_meta = torch.tensor(np.load(f'/kaggle/input/skin-cancer-graph-efficientnet/node_features/{id_}_node_feature_meta.npy').astype('float32'))
    #node_feature = torch.tensor(np.concatenate([node_feature_eff, node_feature_eva], axis=-1))
    node_feature = torch.tensor(np.load(f'/kaggle/input/skin-cancer-graph-efficientnet/node_features/{id_}_node_feature_img.npy').astype('float32'))
    edge_index = torch.tensor(np.load(f'/kaggle/input/img-patient-wise-structural-relation-of-skin-c/pos/{id_}_edge_index.npy'))
    pos = torch.tensor(np.load(f'/kaggle/input/img-patient-wise-structural-relation-of-skin-c/edge_index/{id_}_pos.npy').astype('float32'))
    graph = Data(x=node_feature, y = node_label, pos=pos, edge_index=edge_index)
    data_list.append(graph)

1042it [00:16, 62.41it/s]


### Build Model

In [9]:
class FocalLoss(torch.nn.Module):
    def __init__(self, alpha=1, gamma=2, reduction='sum'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        CE_loss = F.cross_entropy(inputs, targets, weight = class_weight, reduction='none')
        pt = torch.exp(-CE_loss)
        F_loss = self.alpha * (1 - pt) ** self.gamma * CE_loss

        if self.reduction == 'mean':
            return F_loss.mean()
        elif self.reduction == 'sum':
            return F_loss.sum()/batch_size
        else:
            return F_loss
        
        
def graph_laplacian_loss(node_embeddings, batch):
    edge_index = batch.edge_index.cuda()
    batch = batch.batch.cuda()
    if edge_index.shape[0]==0:
        return torch.tensor(0.)
    adjacency_matrix = to_dense_adj(edge_index, batch=batch)
    degree_matrix = torch.diag(adjacency_matrix.sum(dim=1))
    laplacian_matrix = degree_matrix - adjacency_matrix
    smoothness_loss = torch.trace(torch.matmul(node_embeddings.T, torch.matmul(laplacian_matrix, node_embeddings)))
    return smoothness_loss


def augmentation(edge_index, p=0.4):
    edge_index, edge_attr = dropout_edge(edge_index, p=p)
    return edge_index


def entropy_loss(logits):
    p = logits.softmax(dim=-1)
    loss = torch.sum(- p *p.log())/batch_size
    return loss

In [10]:
class GNN_Model(nn.Module):
    def __init__(self, in_channels = 128, hidden_channels = 128, out_channels = 2):
        super().__init__()
        self.in_channels = in_channels
        self.graphsage = GraphSAGE(in_channels, hidden_channels,
                    6, out_channels, jk='lstm', norm = LayerNorm(hidden_channels),
                    dropout = 0.3,  normalize=True)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        if edge_index.shape[0]==0:
            edge_index = torch.tensor([[0, 0]]).T.cuda()
        x = x.cuda()
        edge_index = edge_index.cuda()
        batch = batch.cuda()
        logits = self.graphsage(x, edge_index, batch=batch)
        return logits
        
    
    def predict(self, batch):
        with torch.no_grad():
            logits = self.forward(batch)
            pred = logits.softmax(dim=-1)[:, 1]
        return pred
    
    def fit(self, train_loader, optimizer, fold):
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
        best_score = 0.
        count=0
        for epoch in range(num_epochs):
            loop = tqdm(train_loader)
            self.train()
            for batch in loop:
                y = batch.y.cuda()
                logits = self.forward(batch)
                l_cls = FocalLoss(alpha = alpha, gamma=gamma)(logits, y)
                l_ent = entropy_loss(logits)
                loss = l_cls + l_ent
                #loss = nn.CrossEntropyLoss(weight=class_weight)(logits, y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                torch.cuda.empty_cache()
                loop.set_postfix(loss=f'epoch: {epoch}--loss: {loss} --l_cls: {l_cls} -- l_ent: {l_ent}')
            scheduler.step()
            if epoch%1==0:
                self.eval()
                labels = []
                preds = []
                for batch in valid_loader:
                    label = batch.y
                    pred = self.predict(batch)
                    preds.append(pred)
                    labels.append(label)
                    torch.cuda.empty_cache()
                preds = torch.cat(preds).cpu().numpy()
                labels = torch.cat(labels).cpu().numpy()
                score = comp_score(pd.DataFrame(labels, columns = ['target']),
                                   pd.DataFrame(preds, columns=["prediction"]), "")
                
                if score>best_score:
                    count=0
                    best_score = score
                    torch.save(self.state_dict(), f'/kaggle/working/model_{fold}.pt')
                else:
                    count+=1
                    if count==6:
                        break
                print(f"epoch: {epoch} - Partial AUC Score: {score:.5f}")

### Kfold Validation

In [11]:
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [12]:
def comp_score(solution: pd.DataFrame, submission: pd.DataFrame, row_id_column_name: str, min_tpr: float=0.80):
    v_gt = abs(np.asarray(solution.values)-1)
    v_pred = np.array([1.0 - x for x in submission.values])
    max_fpr = abs(1-min_tpr)
    partial_auc_scaled = roc_auc_score(v_gt, v_pred, max_fpr=max_fpr)
    # change scale from [0.5, 1.0] to [0.5 * max_fpr**2, max_fpr]
    # https://math.stackexchange.com/questions/914823/shift-numbers-into-a-different-range
    partial_auc = 0.5 * max_fpr**2 + (max_fpr - 0.5 * max_fpr**2) / (1.0 - 0.5) * (partial_auc_scaled - 0.5)
    return partial_auc

In [13]:
set_seed()

In [14]:
df_train = df.copy()

In [15]:
gkf = GroupKFold(n_splits=kfold)
df_train["fold"] = -1
for idx, (train_idx, val_idx) in enumerate(gkf.split(df_train, df_train["target"], groups=df_train["patient_id"])):
    df_train.loc[val_idx, "fold"] = idx

In [16]:
models = []
for fold in range(5):
    _df_train = df_train[df_train["fold"] != fold].reset_index(drop=True)
    _df_valid = df_train[df_train["fold"] == fold].reset_index(drop=True)
    #train
    mask = np.isin(df.patient_id.unique(), _df_train.patient_id.unique())
    idx = np.where(mask)[0]
    data_train_list = list(map(lambda i: data_list[i] ,idx))
    #val
    mask = np.isin(df.patient_id.unique(), _df_valid.patient_id.unique())
    idx = np.where(mask)[0]
    data_valid_list = list(map(lambda i: data_list[i] ,idx))

    #loader
    train_loader = DataLoader(data_train_list, batch_size=batch_size, shuffle=True, 
                              num_workers = 4, persistent_workers=True)
    valid_loader = DataLoader(data_valid_list, batch_size=batch_size, shuffle=False,
                              num_workers = 4, persistent_workers=True)
    
    #modeling
    model = GNN_Model().cuda()
    optimizer = optim.AdamW(model.parameters())
    model.fit(train_loader, optimizer, fold)
    models.append(model)
    torch.cuda.empty_cache()

100%|██████████| 84/84 [00:13<00:00,  6.24it/s, loss=epoch: 0--loss: 39.478397369384766 --l_cls: 3.4125454425811768 -- l_ent: 36.065853118896484]


epoch: 0 - Partial AUC Score: 0.10634


100%|██████████| 84/84 [00:13<00:00,  6.44it/s, loss=epoch: 1--loss: 23.951841354370117 --l_cls: 3.0293314456939697 -- l_ent: 20.922510147094727]


epoch: 1 - Partial AUC Score: 0.11973


100%|██████████| 84/84 [00:12<00:00,  6.67it/s, loss=epoch: 2--loss: 120.65739440917969 --l_cls: 95.93675994873047 -- l_ent: 24.720630645751953]


epoch: 2 - Partial AUC Score: 0.12538


100%|██████████| 84/84 [00:12<00:00,  6.80it/s, loss=epoch: 3--loss: 65.28377532958984 --l_cls: 42.63838577270508 -- l_ent: 22.6453914642334]


epoch: 3 - Partial AUC Score: 0.13032


100%|██████████| 84/84 [00:12<00:00,  6.92it/s, loss=epoch: 4--loss: 50.40668487548828 --l_cls: 23.498464584350586 -- l_ent: 26.908222198486328]


epoch: 4 - Partial AUC Score: 0.13218


100%|██████████| 84/84 [00:12<00:00,  6.76it/s, loss=epoch: 5--loss: 35.29804611206055 --l_cls: 10.167963981628418 -- l_ent: 25.130081176757812]


epoch: 5 - Partial AUC Score: 0.13301


100%|██████████| 84/84 [00:12<00:00,  6.73it/s, loss=epoch: 6--loss: 21.520263671875 --l_cls: 8.16928768157959 -- l_ent: 13.350976943969727]


epoch: 6 - Partial AUC Score: 0.13307


100%|██████████| 84/84 [00:12<00:00,  6.69it/s, loss=epoch: 7--loss: 29.197547912597656 --l_cls: 15.160639762878418 -- l_ent: 14.036908149719238]


epoch: 7 - Partial AUC Score: 0.13422


100%|██████████| 84/84 [00:12<00:00,  6.68it/s, loss=epoch: 8--loss: 26.17302703857422 --l_cls: 10.164652824401855 -- l_ent: 16.00837516784668]


epoch: 8 - Partial AUC Score: 0.13352


100%|██████████| 84/84 [00:13<00:00,  6.44it/s, loss=epoch: 9--loss: 73.21129608154297 --l_cls: 52.05058670043945 -- l_ent: 21.16071128845215]


epoch: 9 - Partial AUC Score: 0.13467


100%|██████████| 84/84 [00:12<00:00,  6.54it/s, loss=epoch: 10--loss: 25.77043914794922 --l_cls: 19.068628311157227 -- l_ent: 6.701809883117676]


epoch: 10 - Partial AUC Score: 0.13372


100%|██████████| 84/84 [00:12<00:00,  6.49it/s, loss=epoch: 11--loss: 18.311609268188477 --l_cls: 5.217589855194092 -- l_ent: 13.094018936157227]


epoch: 11 - Partial AUC Score: 0.13384


100%|██████████| 84/84 [00:12<00:00,  6.53it/s, loss=epoch: 12--loss: 33.6170539855957 --l_cls: 11.485278129577637 -- l_ent: 22.131776809692383]


epoch: 12 - Partial AUC Score: 0.13396


100%|██████████| 84/84 [00:12<00:00,  6.55it/s, loss=epoch: 13--loss: 15.054200172424316 --l_cls: 11.670591354370117 -- l_ent: 3.38360857963562]


epoch: 13 - Partial AUC Score: 0.13151


100%|██████████| 84/84 [00:12<00:00,  6.56it/s, loss=epoch: 14--loss: 34.87299728393555 --l_cls: 24.009511947631836 -- l_ent: 10.863486289978027]


epoch: 14 - Partial AUC Score: 0.13373


100%|██████████| 84/84 [00:12<00:00,  6.59it/s, loss=epoch: 15--loss: 26.883813858032227 --l_cls: 20.845001220703125 -- l_ent: 6.038812160491943]
100%|██████████| 84/84 [00:13<00:00,  6.01it/s, loss=epoch: 0--loss: 109.8994140625 --l_cls: 45.13736343383789 -- l_ent: 64.76204681396484]


epoch: 0 - Partial AUC Score: 0.10632


100%|██████████| 84/84 [00:13<00:00,  6.21it/s, loss=epoch: 1--loss: 15.111927032470703 --l_cls: 9.069674491882324 -- l_ent: 6.042252540588379]


epoch: 1 - Partial AUC Score: 0.10688


100%|██████████| 84/84 [00:14<00:00,  5.93it/s, loss=epoch: 2--loss: 15.122093200683594 --l_cls: 3.909868001937866 -- l_ent: 11.212224960327148]


epoch: 2 - Partial AUC Score: 0.11105


100%|██████████| 84/84 [00:13<00:00,  6.10it/s, loss=epoch: 3--loss: 16.51709747314453 --l_cls: 4.51080322265625 -- l_ent: 12.006293296813965]


epoch: 3 - Partial AUC Score: 0.12124


100%|██████████| 84/84 [00:13<00:00,  6.26it/s, loss=epoch: 4--loss: 24.142459869384766 --l_cls: 11.122983932495117 -- l_ent: 13.019476890563965]


epoch: 4 - Partial AUC Score: 0.12389


100%|██████████| 84/84 [00:13<00:00,  6.18it/s, loss=epoch: 5--loss: 23.05474853515625 --l_cls: 11.909868240356445 -- l_ent: 11.144879341125488]


epoch: 5 - Partial AUC Score: 0.12053


100%|██████████| 84/84 [00:13<00:00,  6.15it/s, loss=epoch: 6--loss: 21.260967254638672 --l_cls: 11.621567726135254 -- l_ent: 9.639398574829102]


epoch: 6 - Partial AUC Score: 0.12006


100%|██████████| 84/84 [00:13<00:00,  6.18it/s, loss=epoch: 7--loss: 7.665403366088867 --l_cls: 3.5777080059051514 -- l_ent: 4.087695121765137]


epoch: 7 - Partial AUC Score: 0.12043


100%|██████████| 84/84 [00:13<00:00,  6.19it/s, loss=epoch: 8--loss: 10.027875900268555 --l_cls: 3.842390537261963 -- l_ent: 6.185485363006592]


epoch: 8 - Partial AUC Score: 0.12168


100%|██████████| 84/84 [00:13<00:00,  6.22it/s, loss=epoch: 9--loss: 7.109288215637207 --l_cls: 4.524547576904297 -- l_ent: 2.5847408771514893]


epoch: 9 - Partial AUC Score: 0.11948


100%|██████████| 84/84 [00:13<00:00,  6.06it/s, loss=epoch: 10--loss: 9.856277465820312 --l_cls: 5.227377891540527 -- l_ent: 4.628900051116943]
100%|██████████| 84/84 [00:14<00:00,  5.97it/s, loss=epoch: 0--loss: 220.33798217773438 --l_cls: 195.38624572753906 -- l_ent: 24.95173454284668]


epoch: 0 - Partial AUC Score: 0.10904


100%|██████████| 84/84 [00:13<00:00,  6.30it/s, loss=epoch: 1--loss: 12.74338436126709 --l_cls: 3.1369731426239014 -- l_ent: 9.60641098022461]


epoch: 1 - Partial AUC Score: 0.14276


100%|██████████| 84/84 [00:13<00:00,  6.05it/s, loss=epoch: 2--loss: 23.672698974609375 --l_cls: 8.492779731750488 -- l_ent: 15.179919242858887]


epoch: 2 - Partial AUC Score: 0.15023


100%|██████████| 84/84 [00:13<00:00,  6.24it/s, loss=epoch: 3--loss: 121.91082000732422 --l_cls: 108.06587219238281 -- l_ent: 13.844949722290039]


epoch: 3 - Partial AUC Score: 0.13118


100%|██████████| 84/84 [00:13<00:00,  6.12it/s, loss=epoch: 4--loss: 19.973764419555664 --l_cls: 7.204634189605713 -- l_ent: 12.76913070678711]


epoch: 4 - Partial AUC Score: 0.14456


100%|██████████| 84/84 [00:13<00:00,  6.07it/s, loss=epoch: 5--loss: 24.28902816772461 --l_cls: 14.041964530944824 -- l_ent: 10.247062683105469]


epoch: 5 - Partial AUC Score: 0.14718


100%|██████████| 84/84 [00:13<00:00,  6.05it/s, loss=epoch: 6--loss: 3.4501209259033203 --l_cls: 1.2402344942092896 -- l_ent: 2.209886312484741]


epoch: 6 - Partial AUC Score: 0.13902


100%|██████████| 84/84 [00:13<00:00,  6.05it/s, loss=epoch: 7--loss: 7.631477355957031 --l_cls: 2.623373508453369 -- l_ent: 5.008103847503662]


epoch: 7 - Partial AUC Score: 0.14059


100%|██████████| 84/84 [00:13<00:00,  6.01it/s, loss=epoch: 8--loss: 25.05340576171875 --l_cls: 18.52168083190918 -- l_ent: 6.5317254066467285]
100%|██████████| 84/84 [00:13<00:00,  6.01it/s, loss=epoch: 0--loss: 182.49575805664062 --l_cls: 77.8823013305664 -- l_ent: 104.61345672607422]


epoch: 0 - Partial AUC Score: 0.11742


100%|██████████| 84/84 [00:13<00:00,  6.08it/s, loss=epoch: 1--loss: 28.16489601135254 --l_cls: 9.461755752563477 -- l_ent: 18.703140258789062]


epoch: 1 - Partial AUC Score: 0.11495


100%|██████████| 84/84 [00:13<00:00,  6.04it/s, loss=epoch: 2--loss: 274.47021484375 --l_cls: 256.9034423828125 -- l_ent: 17.5667724609375]


epoch: 2 - Partial AUC Score: 0.12324


100%|██████████| 84/84 [00:13<00:00,  6.14it/s, loss=epoch: 3--loss: 7.601358413696289 --l_cls: 2.204042434692383 -- l_ent: 5.397315979003906]


epoch: 3 - Partial AUC Score: 0.12492


100%|██████████| 84/84 [00:14<00:00,  5.99it/s, loss=epoch: 4--loss: 62.07304763793945 --l_cls: 49.97394943237305 -- l_ent: 12.099098205566406]


epoch: 4 - Partial AUC Score: 0.12552


100%|██████████| 84/84 [00:13<00:00,  6.06it/s, loss=epoch: 5--loss: 4.543774604797363 --l_cls: 0.9608314633369446 -- l_ent: 3.5829429626464844]


epoch: 5 - Partial AUC Score: 0.12776


100%|██████████| 84/84 [00:13<00:00,  6.02it/s, loss=epoch: 6--loss: 16.92107391357422 --l_cls: 9.520340919494629 -- l_ent: 7.400732517242432]


epoch: 6 - Partial AUC Score: 0.13002


100%|██████████| 84/84 [00:13<00:00,  6.02it/s, loss=epoch: 7--loss: 79.4237289428711 --l_cls: 67.6695785522461 -- l_ent: 11.754151344299316]


epoch: 7 - Partial AUC Score: 0.12916


100%|██████████| 84/84 [00:13<00:00,  6.12it/s, loss=epoch: 8--loss: 11.922637939453125 --l_cls: 8.164815902709961 -- l_ent: 3.757821798324585]


epoch: 8 - Partial AUC Score: 0.13497


100%|██████████| 84/84 [00:13<00:00,  6.17it/s, loss=epoch: 9--loss: 4.1697797775268555 --l_cls: 0.982780933380127 -- l_ent: 3.1869986057281494]


epoch: 9 - Partial AUC Score: 0.13323


100%|██████████| 84/84 [00:13<00:00,  6.15it/s, loss=epoch: 10--loss: 5.954334259033203 --l_cls: 1.871267557144165 -- l_ent: 4.083066463470459]


epoch: 10 - Partial AUC Score: 0.13303


100%|██████████| 84/84 [00:13<00:00,  6.14it/s, loss=epoch: 11--loss: 10.546645164489746 --l_cls: 3.40971302986145 -- l_ent: 7.136932373046875]


epoch: 11 - Partial AUC Score: 0.13175


100%|██████████| 84/84 [00:13<00:00,  6.09it/s, loss=epoch: 12--loss: 9.722038269042969 --l_cls: 3.112870693206787 -- l_ent: 6.609167575836182]


epoch: 12 - Partial AUC Score: 0.13236


100%|██████████| 84/84 [00:13<00:00,  6.10it/s, loss=epoch: 13--loss: 4.790816307067871 --l_cls: 1.1299736499786377 -- l_ent: 3.6608428955078125]


epoch: 13 - Partial AUC Score: 0.13459


100%|██████████| 84/84 [00:13<00:00,  6.14it/s, loss=epoch: 14--loss: 9.540779113769531 --l_cls: 1.9783210754394531 -- l_ent: 7.562458038330078]


epoch: 14 - Partial AUC Score: 0.13560


100%|██████████| 84/84 [00:13<00:00,  6.16it/s, loss=epoch: 15--loss: 3.057746410369873 --l_cls: 0.9659927487373352 -- l_ent: 2.0917537212371826]


epoch: 15 - Partial AUC Score: 0.13519


100%|██████████| 84/84 [00:13<00:00,  6.07it/s, loss=epoch: 16--loss: 2.9026012420654297 --l_cls: 1.2180591821670532 -- l_ent: 1.6845420598983765]


epoch: 16 - Partial AUC Score: 0.13411


100%|██████████| 84/84 [00:13<00:00,  6.16it/s, loss=epoch: 17--loss: 0.848926305770874 --l_cls: 0.18854117393493652 -- l_ent: 0.6603851318359375]


epoch: 17 - Partial AUC Score: 0.13484


100%|██████████| 84/84 [00:13<00:00,  6.09it/s, loss=epoch: 18--loss: 11.05074691772461 --l_cls: 4.419528484344482 -- l_ent: 6.631217956542969]


epoch: 18 - Partial AUC Score: 0.13383


100%|██████████| 84/84 [00:13<00:00,  6.05it/s, loss=epoch: 19--loss: 11.018938064575195 --l_cls: 4.186402320861816 -- l_ent: 6.832535743713379]


epoch: 19 - Partial AUC Score: 0.13400


100%|██████████| 84/84 [00:13<00:00,  6.10it/s, loss=epoch: 20--loss: 2.8663811683654785 --l_cls: 0.7423807382583618 -- l_ent: 2.1240005493164062]
100%|██████████| 84/84 [00:14<00:00,  5.99it/s, loss=epoch: 0--loss: 28.83861541748047 --l_cls: 6.976803779602051 -- l_ent: 21.8618106842041]


epoch: 0 - Partial AUC Score: 0.08510


100%|██████████| 84/84 [00:14<00:00,  5.96it/s, loss=epoch: 1--loss: 18.04572105407715 --l_cls: 4.539994239807129 -- l_ent: 13.50572681427002]


epoch: 1 - Partial AUC Score: 0.09571


100%|██████████| 84/84 [00:13<00:00,  6.12it/s, loss=epoch: 2--loss: 12.314702033996582 --l_cls: 5.548490047454834 -- l_ent: 6.766211986541748]


epoch: 2 - Partial AUC Score: 0.08238


100%|██████████| 84/84 [00:13<00:00,  6.05it/s, loss=epoch: 3--loss: 10.912632942199707 --l_cls: 2.6808178424835205 -- l_ent: 8.231815338134766]


epoch: 3 - Partial AUC Score: 0.09913


100%|██████████| 84/84 [00:13<00:00,  6.04it/s, loss=epoch: 4--loss: 5.556187629699707 --l_cls: 2.162412166595459 -- l_ent: 3.393775701522827]


epoch: 4 - Partial AUC Score: 0.09828


100%|██████████| 84/84 [00:13<00:00,  6.06it/s, loss=epoch: 5--loss: 1.9879988431930542 --l_cls: 0.9694887399673462 -- l_ent: 1.018510103225708]


epoch: 5 - Partial AUC Score: 0.08678


100%|██████████| 84/84 [00:14<00:00,  5.98it/s, loss=epoch: 6--loss: 8.21368408203125 --l_cls: 4.198449611663818 -- l_ent: 4.015234470367432]


epoch: 6 - Partial AUC Score: 0.07319


100%|██████████| 84/84 [00:13<00:00,  6.06it/s, loss=epoch: 7--loss: 1.781416654586792 --l_cls: 0.7897208333015442 -- l_ent: 0.9916958212852478]


epoch: 7 - Partial AUC Score: 0.07510


100%|██████████| 84/84 [00:14<00:00,  5.99it/s, loss=epoch: 8--loss: 7.570220947265625 --l_cls: 3.073209047317505 -- l_ent: 4.497011661529541]


epoch: 8 - Partial AUC Score: 0.08106


100%|██████████| 84/84 [00:13<00:00,  6.00it/s, loss=epoch: 9--loss: 18.415102005004883 --l_cls: 7.700491428375244 -- l_ent: 10.714611053466797]


In [17]:
loader = DataLoader(data_list, batch_size=batch_size, shuffle=False, 
                              num_workers = 4, persistent_workers=True)

In [18]:
patient_ids = df.patient_id.unique()
df_preds = []

for fold in range(5):
    model = models[fold]
    model.load_state_dict(torch.load(f'/kaggle/working/model_{fold}.pt'))
    model.eval()
    df_pred = pd.DataFrame(df[['isic_id', 'patient_id']], columns=['isic_id', 'patient_id', f'pred_{fold}'],
                               index=df.index)
    for i, batch in enumerate(tqdm(loader)):
        preds = model.predict(batch)
        batch_idx = batch.batch.unique()
        for idx in batch_idx:
            pred_i = preds[batch.batch == idx].cpu().numpy()
            sample_idx = idx + i * batch_size
            patient_id = patient_ids[sample_idx]
            df_pred.loc[df_pred.patient_id == patient_id, f'pred_{fold}'] = pred_i
    df_preds.append(df_pred)
df_preds = pd.concat(df_preds, axis=1)

100%|██████████| 105/105 [01:17<00:00,  1.36it/s]
100%|██████████| 105/105 [01:16<00:00,  1.36it/s]
100%|██████████| 105/105 [01:16<00:00,  1.36it/s]
100%|██████████| 105/105 [01:18<00:00,  1.35it/s]
100%|██████████| 105/105 [01:18<00:00,  1.34it/s]


In [19]:
df_preds.to_csv('/kaggle/working/gnn_preds.csv')

In [20]:
#df_preds[df.target==1].hist()