In [None]:
import torch
from torch import Tensor
from torch import optim
from torch.utils.data import DataLoader
from prompt_based.prompt import PromptModel
from utils.MetricsHistory import MetricsHistory
from utils.weighted_loss import WeightedDiceNLLLoss
from utils.utils import calculate_class_weights
from utils.dataset import promptDataset, diff_size_collate
from utils.training import start_prompt

# --- Label handling -------------------------------------------------------
NUM_CLASSES = 4
# Handed as the ignore index to the validation loss, MetricsHistory and
# start_prompt further down in this cell.
EVAL_IGNORE_INDEX = 3
# Handed as the ignore index to the training loss (None here).
TRAIN_IGNORE_INDEX = None

# --- Checkpointing --------------------------------------------------------
MODEL_SAVE_DIR = "tmp"
MODEL_NAME = "tmp.pytorch"
LOAD = False
SAVE = False

# --- Optimisation ---------------------------------------------------------
# NOTE(review): EPOCHS is not referenced anywhere else in the visible part of
# this file — confirm how the number of epochs reaches the training loop.
EPOCHS = 100
WEIGHT_DECAY = 0.01

# --- Model / input configuration ------------------------------------------
TARGET_SIZE = 224
# NOTE(review): the next two settings are not referenced in the visible part
# of this file.
PRETRAINED_MODEL_NAME = "openai/clip-vit-base-patch16"
SKIP_LAYER_INDICES = [3, 5, 7, 9]
# Only used by the commented-out PromptModel(CLIP_PATH) variant below.
CLIP_PATH = "/content/drive/MyDrive/clip/runs/clip_256_ce_dice_full_weight_fix_train_eval.pytorch"

# Device selection: use MPS when available, otherwise CUDA, otherwise the CPU.
# The checks run in that order, so MPS wins if both report availability.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

# Class weights
# Only the uniform weighting below is in effect. The following weightings used
# to be assigned here and were overwritten before ever being used; they are
# kept purely as a record:
#   [0.33265044664009075, 1.669423957743164, 1.9979255956167454, 0.0]
#   [0.30711034803008996, 1.5412496145750956, 1.8445296893647247, 0.30711034803008996]
#   [0.2046795970925636, 1.0271954434416883, 1.2293222812780409, 1.5388026781877073]
#
# NOTE(review): a computed alternative was also left disabled:
#   calculate_class_weights_v3(training_data, 4, None, "dataset")
# It names `calculate_class_weights_v3` while only `calculate_class_weights`
# is imported, and it needs `training_data`, which is created further down —
# both must be sorted out before it can be enabled.
#
# The tensor is created directly on the target device as float32 (the default
# dtype), which matches what Tensor([1, 1, 1, 1]).to(device) produced.
class_weight = torch.tensor([1.0, 1.0, 1.0, 1.0], device=device)

# Batch sizing. Each DataLoader step yields `batch_size` samples;
# `accumulation_steps` (computed at the end of this section) is the integer
# ratio target_batch_size / batch_size and is handed to start_prompt.
# Presumably it drives gradient accumulation so the effective batch size
# approaches `target_batch_size` — confirm in utils.training.
target_batch_size = 64
batch_size = 2

# Datasets: each split is built from three directories (color images, point
# prompts, labels), passed positionally in that order.
training_data = promptDataset("datasets/pstrain/color", "datasets/pstrain/point_prompt", "datasets/pstrain/label")
val_data = promptDataset("datasets/psVal/color", "datasets/psVal/point_prompt", "datasets/psVal/label")
test_data = promptDataset("datasets/psTest/color", "datasets/psTest/point_prompt", "datasets/psTest/label")

# Page-locked host memory only speeds up host-to-CUDA copies. On MPS or CPU
# the DataLoader cannot use it and emits a warning, so request it only when
# the selected device is CUDA.
use_pinned_memory = device.type == "cuda"

# NOTE(review): the training loader uses the default collate function while
# the validation/test loaders use diff_size_collate — confirm that training
# samples always share one size.
train_dataloader = DataLoader(
    training_data,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=use_pinned_memory,
)
# NOTE(review): shuffle=True is kept on the evaluation loaders to preserve
# existing behaviour; it is usually unnecessary for validation/testing.
val_dataloader = DataLoader(
    val_data,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=use_pinned_memory,
    collate_fn=diff_size_collate,
)
# NOTE(review): test_dataloader is not used in the visible part of this file.
test_dataloader = DataLoader(
    test_data,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=use_pinned_memory,
    collate_fn=diff_size_collate,
)

accumulation_steps = target_batch_size // batch_size

# Model
# Built with PromptModel's default arguments and then moved to the selected
# device. To start from the pretrained CLIP checkpoint instead, construct it
# as PromptModel(CLIP_PATH).
model = PromptModel()
model = model.to(device)

# Losses
def stable_log(x: Tensor) -> Tensor:
    """Return log(x + 1e-9); the small offset keeps the result finite at x == 0."""
    return torch.log(x + 1e-9)

# Both losses share the class weights, skip the internal softmax
# (apply_softmax=False) and use stable_log as the NLL non-linearity. They
# differ in the ignore index, and only the training loss sets smooth_dice.
# NOTE(review): the validation loss falls back to WeightedDiceNLLLoss's
# default smooth_dice — confirm that this asymmetry is intended.
train_loss_fn = WeightedDiceNLLLoss(
    class_weights=class_weight,
    ignore_index=TRAIN_IGNORE_INDEX,
    smooth_dice=1,
    apply_softmax=False,
    nll_nonlin=stable_log,
)
val_loss_fn = WeightedDiceNLLLoss(
    class_weights=class_weight,
    ignore_index=EVAL_IGNORE_INDEX,
    apply_softmax=False,
    nll_nonlin=stable_log,
)

# Optimizer: AdamW over all model parameters. Only the weight decay is
# overridden; every other hyper-parameter keeps the library default.
optimizer = optim.AdamW(
    model.parameters(),
    weight_decay=WEIGHT_DECAY,
)

# Scheduler: none is configured; None is passed through to start_prompt.
scheduler = None

# Metric history, constructed from the class count and the evaluation
# ignore index.
agg = MetricsHistory(
    NUM_CLASSES,
    EVAL_IGNORE_INDEX,
)

# Training pipeline
# Everything configured above is handed to start_prompt by keyword.
# NOTE(review): EPOCHS is defined at the top of this cell but is not passed
# here — confirm where the epoch count comes from.
start_prompt(
    # model and optimisation
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    accumulation_steps=accumulation_steps,
    device=device,
    # data
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    target_size=TARGET_SIZE,
    # losses and metrics
    train_loss_fn=train_loss_fn,
    val_loss_fn=val_loss_fn,
    agg=agg,
    num_classes=NUM_CLASSES,
    ignore_index=EVAL_IGNORE_INDEX,
    # checkpointing
    model_save_dir=MODEL_SAVE_DIR,
    model_save_name=MODEL_NAME,
    load=LOAD,
    save=SAVE,
)