In [None]:
import torch
from torch import nn
from torch import Tensor
from torch import optim
from torch.utils.data import DataLoader
from utils.training import start
from utils.MetricsHistory import MetricsHistory
from unet.unet import unet
from utils.weighted_loss import WeightedDiceCELoss
from utils.utils import calculate_class_weights
from utils.dataset import dataset, target_remap, diff_size_collate

# --- Label / class configuration ---
NUM_CLASSES = 4
# Class index handed to the validation loss, the metrics history and the
# training pipeline as the index to ignore.
EVAL_IGNORE_INDEX = 3
# Ignore index for the training-side loss; None presumably means that no
# class is ignored during training -- confirm against the loss implementation.
TRAIN_IGNORE_INDEX = None

# --- Checkpointing (forwarded to the training pipeline) ---
MODEL_SAVE_DIR = "tmp"
MODEL_NAME = "tmp.pytorch"
LOAD = False
SAVE = False

# --- Training hyper-parameters ---
EPOCHS = 100
WEIGHT_DECAY = 0.01
TARGET_SIZE = 256

# Pick the compute device once: Apple MPS is probed first, then CUDA, and the
# CPU is the fallback when neither accelerator reports as available.
device = torch.device(
    "mps"
    if torch.backends.mps.is_available()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)

# batch_size is the per-step batch given to the data loaders;
# target_batch_size is the larger effective batch from which the number of
# gradient-accumulation steps is derived further down.
batch_size = 2
target_batch_size = 64

# Page-locked host memory only benefits host-to-device copies on CUDA. On MPS
# or CPU recent PyTorch versions warn and skip pinning, so request it only
# when the selected device is a CUDA device.
use_pinned_memory = device.type == "cuda"

# Create datasets for training, validation, and testing
training_data = dataset("datasets/astrain/color", "datasets/astrain/label", target_transform=target_remap())
val_data = dataset("datasets/Val/color", "datasets/Val/label", target_transform=target_remap())
test_data = dataset("datasets/Test/color", "datasets/Test/label", target_transform=target_remap())

# Create data loaders for training, validation, and testing
# NOTE(review): the training loader uses the default collate function while the
# evaluation loaders use diff_size_collate -- presumably the training images
# all share one size; confirm, otherwise batching will fail.
train_dataloader = DataLoader(
    training_data,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=use_pinned_memory,
)
# Evaluation data is not shuffled, so every validation/test pass visits the
# samples in the same order and results are reproducible batch by batch.
val_dataloader = DataLoader(
    val_data,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=use_pinned_memory,
    collate_fn=diff_size_collate,
)
test_dataloader = DataLoader(
    test_data,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=use_pinned_memory,
    collate_fn=diff_size_collate,
)


# Class Weights
# Per-class weighting is disabled: class_weight ends up as None. The earlier
# version built two weight tensors and moved one to the device only to
# overwrite the result with None, so that dead work has been removed.
#
# Weight vectors computed earlier, kept for reference. To re-enable weighting,
# assign one of them instead of None and move it to `device` with .to(device):
#   Tensor([0.30711034803008996, 1.5412496145750956, 1.8445296893647247, 0.30711034803008996])
#   Tensor([0.2046795970925636, 1.0271954434416883, 1.2293222812780409, 1.5388026781877073])
#   Tensor([1, 1, 1, 1])
#   calculate_class_weights(training_data, 4, None, "dataset")
class_weight = None

# Calculate the number of accumulation steps: how many mini-batches of
# batch_size are combined to approximate target_batch_size. Clamped to at
# least 1 so that a batch_size larger than target_batch_size cannot produce
# zero accumulation steps.
accumulation_steps = max(1, target_batch_size // batch_size)

# Model
# The class count comes from NUM_CLASSES instead of a repeated literal so the
# network output and the metrics configuration cannot drift apart.
# NOTE(review): assumes the signature is unet(in_channels, num_classes) -- confirm.
model = unet(3, NUM_CLASSES).to(device)

# Losses
# Plain cross-entropy is the loss in effect: training uses the default
# settings, validation additionally skips targets equal to EVAL_IGNORE_INDEX.
# The earlier version constructed two other loss pairs first and immediately
# overwrote them; only the final pair was ever used, so only it is built here.
#
# Alternatives tried earlier, kept for reference (train / val):
#   WeightedDiceCELoss(ignore_index=TRAIN_IGNORE_INDEX, smooth_dice=1, class_weights=class_weight)
#   WeightedDiceCELoss(ignore_index=EVAL_IGNORE_INDEX, class_weights=class_weight)
#
#   nn.CrossEntropyLoss(weight=class_weight)
#   nn.CrossEntropyLoss(weight=class_weight, ignore_index=EVAL_IGNORE_INDEX)
train_loss_fn = nn.CrossEntropyLoss()
val_loss_fn = nn.CrossEntropyLoss(ignore_index=EVAL_IGNORE_INDEX)

# Optimizer: AdamW over all model parameters, learning rate left at its default
optimizer = optim.AdamW(
    params=model.parameters(),
    weight_decay=WEIGHT_DECAY,
)

# Scheduler: none is configured; None is forwarded to the training pipeline
scheduler = None

# Metric History, set up with the class count and the evaluation ignore index
agg = MetricsHistory(NUM_CLASSES, EVAL_IGNORE_INDEX)

# Training Pipeline
# Everything configured above is handed to utils.training.start by keyword.
start(
    # model and optimisation objects
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_loss_fn=train_loss_fn,
    val_loss_fn=val_loss_fn,
    # data
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    accumulation_steps=accumulation_steps,
    target_size=TARGET_SIZE,
    device=device,
    # metrics
    agg=agg,
    num_classes=NUM_CLASSES,
    ignore_index=EVAL_IGNORE_INDEX,
    # checkpointing
    model_save_dir=MODEL_SAVE_DIR,
    model_save_name=MODEL_NAME,
    load=LOAD,
    save=SAVE,
)