In [1]:
from prototypes.deeplearning.dataloader.IsicDataLoader import LoadDataVectors
from prototypes.deeplearning.trainner import train_single_task
import torch
import json
from prototypes.utility.data import ProjectConfiguration
import torchvision
import os

In [2]:
# Select the pretrained checkpoint once and build the network from that same
# enum member, so the preprocessing obtained from weights.transforms() always
# matches the weights the model was loaded with.
weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V2
model = torchvision.models.resnet50(weights=weights)

In [3]:
# Project-wide settings object built from "../config.json"; the relative path
# is resolved against the kernel's working directory.
# NOTE(review): presumably this parses the JSON file eagerly — confirm in
# prototypes.utility.data.ProjectConfiguration.
config = ProjectConfiguration("../config.json")

In [4]:
# Display the names of every setting exposed by the configuration object.
config.get_keys()

dict_keys(['DATASET_PATH', 'TRAIN_IMAGES_PATH', 'TRAIN_METADATA', 'TEST_METADATA', 'SAMPLE_SUBMISSION', 'IMAGE_WIDTH', 'IMAGE_HEIGHT', 'TARGET_COLUMNS', 'VECTORS_PATH', 'BATCH_SIZE', 'K_FOLDS', 'NUM_CLASSES', 'ALPHA', 'TRAIN_DEVICE', 'NUM_EPOCHS', 'NUM_WORKERS', 'SAMPLE_PERCENTAGE', 'HYPER_PARAMETERS_PATH', 'VERSION'])

In [5]:
# Sanity check: show the configured location of the training metadata CSV.
config.get_value("TRAIN_METADATA")

'/home/matias/workspace/datasets/isic-2024-challenge/train-metadata.csv'

In [None]:
# Resolve the two input files first so the constructor call stays readable.
train_hdf5_path = os.path.join(config.get_value("DATASET_PATH"), "train-image.hdf5")
train_metadata_csv = config.get_value("TRAIN_METADATA")

# NOTE(review): despite its name, `dataloader` is consumed as a dataset by the
# split cell; the name is kept because later cells reference it.
# The transform is the preprocessing pipeline bundled with the chosen weights.
dataloader = LoadDataVectors(
    hd5_file_path=train_hdf5_path,
    metadata_csv_path=train_metadata_csv,
    target_columns=["target"],
    transform=weights.transforms(),
)

In [None]:
# 80/20 train/validation split expressed as fractions of the dataset length.
# A dedicated, explicitly seeded generator makes the split reproducible across
# kernel restarts; without it the membership of `train` and `val` changes on
# every run and validation metrics are not comparable between runs.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train, val = torch.utils.data.random_split(dataloader, [0.8, 0.2], generator=split_generator)

In [None]:
# Display the number of samples in the training and validation subsets.
len(train), len(val)

In [None]:
# Read the loader settings once; both loaders share the worker count.
batch_size = config.get_value("BATCH_SIZE")
num_workers = config.get_value("NUM_WORKERS")

# Training batches are reshuffled on every pass over the data.
train_dataloader = torch.utils.data.DataLoader(
    train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

# Validation uses half the training batch size and a fixed sample order.
val_dataloader = torch.utils.data.DataLoader(
    val,
    batch_size=batch_size // 2,
    shuffle=False,
    num_workers=num_workers,
)

In [None]:
# Swap the pretrained classifier for a linear layer over the 2048 input
# features followed by a sigmoid, so the network emits probabilities as
# required by the BCELoss criterion below.
classifier_head = torch.nn.Sequential(
    torch.nn.Linear(2048, config.get_value("NUM_CLASSES")),
    torch.nn.Sigmoid(),
)
model.fc = classifier_head

# Move the model to the training device before creating the optimizer so the
# optimizer is built over the parameters that are actually trained.
device = config.get_value("TRAIN_DEVICE")
model = model.to(device=device)

optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
criterion = torch.nn.BCELoss()

train_single_task(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    epochs=config.get_value("NUM_EPOCHS"),
    alpha=config.get_value("ALPHA"),
)