In [None]:
from pickle import TRUE
from tqdm import tqdm
import torch
import os
import numpy as np
from torch.utils.data import DataLoader
from torch import nn

from utils.dataset import *
from utils.utils import *
from utils.training import trainReconstruction, evalReconstruction
from autoencoder.autoencoder import ReconstructionAutoencoder


# Batch configuration. `batch_size` is the number of samples per DataLoader
# batch; `target_batch_size` is the batch size the run is meant to emulate.
# `accumulation_steps` is later handed to trainReconstruction().
batch_size = 2
target_batch_size = 64
# Clamp to at least 1 so that a physical batch larger than the target cannot
# produce zero accumulation steps.
accumulation_steps = max(1, target_batch_size // batch_size)

# One dataset per split, each built from a colour-image folder and a label
# folder; labels are passed through target_remap().
training_data = dataset("datasets/astrain/color", "datasets/astrain/label", target_transform=target_remap())
validation_data = dataset("datasets/Val/color", "datasets/Val/label", target_transform=target_remap())
test_data = dataset("datasets/Test/color", "datasets/Test/label", target_transform=target_remap())

# Only the training loader shuffles. The evaluation loaders keep a fixed order
# so that the batch composition (and therefore the reported validation/test
# loss) is the same from run to run.
# NOTE(review): val/test use diff_size_collate while train uses the default
# collate -- presumably the training images share one size; confirm.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, pin_memory=True)
val_dataloader = DataLoader(validation_data, batch_size=batch_size, shuffle=False, pin_memory=True, collate_fn=diff_size_collate)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False, pin_memory=True, collate_fn=diff_size_collate)


# Select the compute device in order of preference: MPS, then CUDA, then CPU.
# The availability checks are evaluated lazily, in that order.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)


# --- Run configuration -------------------------------------------------------
learning_rate = 1e-3
TARGET_SIZE = 256             # forwarded to evalReconstruction(target_size=...)
MODEL_SAVE_DIR = "tmp"        # directory holding the checkpoint file
MODEL_NAME = "tmp.pytorch"    # checkpoint file name inside MODEL_SAVE_DIR
recoverCheckpoint = False     # set True to resume from an existing checkpoint
start_epoch = 0               # first epoch index of the training loop

# --- Model, loss and optimizer -----------------------------------------------
model = ReconstructionAutoencoder(din=3, dout=3)
model = model.to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Optional checkpoint recovery.
#
# `best_val_loss` receives its default *before* the optional load so that a
# value restored from the checkpoint is not reset afterwards (the default used
# to be assigned after the load, which discarded the restored value).
best_val_loss = np.inf
checkpoint_file = f"{MODEL_SAVE_DIR}/{MODEL_NAME}"

if not recoverCheckpoint:
    print("Checkpoint recovery disabled. Starting training from scratch.")
elif not os.path.isfile(checkpoint_file):
    print(f"Checkpoint file not found at {checkpoint_file}. Starting training from scratch.")
else:
    print(f"Loading checkpoint from: {checkpoint_file}")
    # Load the checkpoint dictionary; move tensors to the correct device.
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoint files that come from a trusted source.
    checkpoint = torch.load(checkpoint_file, map_location=device, weights_only=False)

    # Load model state
    model.load_state_dict(checkpoint["model_state_dict"])
    print(" -> Model state loaded.")

    # Load optimizer state. Deliberately best-effort: a failure here is
    # reported and training continues with a fresh optimizer.
    try:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        print(" -> Optimizer state loaded.")
    except Exception as e:
        print(f" -> Warning: Could not load optimizer state: {e}. Optimizer will start from scratch.")

    # Load training metadata. "epoch" holds the number of the last completed
    # epoch (1-based), so using it as the 0-based start index of the training
    # loop continues with the next epoch.
    start_epoch = checkpoint.get("epoch", 0)
    best_val_loss = checkpoint.get("best_val_loss", np.inf)

    print(f" -> Resuming training from epoch {start_epoch + 1}")
    loaded_notes = checkpoint.get("notes", "N/A")
    print(f" -> Notes from checkpoint: {loaded_notes}")

EPOCHS = 100

# Output locations. Both directories are created up front so that a missing
# folder cannot abort the run at the first checkpoint or figure save.
preview_dir = "drive/MyDrive/autoencoder/images"
checkpoint_path = os.path.join(MODEL_SAVE_DIR, f"{MODEL_NAME}")
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
os.makedirs(preview_dir, exist_ok=True)

print("\nStarting Training (Autoencoder)...")
for t in range(start_epoch, EPOCHS):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loss = trainReconstruction(train_dataloader, model, loss_fn, optimizer, accumulation_steps)

    # evalReconstruction returns two losses; only the second one drives model
    # selection below.
    wrong_val_loss, correct_val_loss = evalReconstruction(val_dataloader, model, loss_fn, target_size=TARGET_SIZE)

    # Save the model whenever the validation loss improves
    if correct_val_loss < best_val_loss:
        print(f"Validation loss improved ({best_val_loss:.6f} → {correct_val_loss:.6f}). Saving model...")
        best_val_loss = correct_val_loss
        checkpoint = {
            "epoch": t + 1,  # number of the epoch that just completed (1-based)
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "best_val_loss": best_val_loss,
        }
        torch.save(checkpoint, checkpoint_path)

    else:
        print(f"Corresponding validation loss: {correct_val_loss:.6f} not better than {best_val_loss}")

    #print(f"Wrong Validation loss: {wrong_val_loss:.6f}")
    print(f"Train loss: {train_loss:.6f}")

    # Plot the reconstruction of the first training sample.
    # NOTE(review): `plt` is not imported explicitly in this cell; presumably it
    # is matplotlib.pyplot leaking in through one of the wildcard imports --
    # confirm, and import it explicitly if so.
    img, _label = training_data[0]
    # No gradients are needed for a preview, so skip building the autograd graph.
    with torch.no_grad():
        res = model(img.to(device).unsqueeze(0))
    plt.imshow(res[0].permute(1,2,0).cpu().numpy())
    plt.savefig(f"{preview_dir}/test{t}.png", format="png")
    plt.show()
    # Release the figure so that figures do not pile up across epochs.
    plt.close()

print("\n--- Training Finished! ---")
print(f"Best model saved to: {checkpoint_path}")

In [None]:
import torch
from torch import Tensor
from torch import optim
from torch.utils.data import DataLoader
from utils.training import start
from utils.MetricsHistory import MetricsHistory
from utils.weighted_loss import WeightedDiceCELoss
from utils.utils import calculate_class_weights
from utils.dataset import dataset, target_remap, diff_size_collate
from autoencoder.autoencoder import SegmentationAutoencoder

# --- Checkpointing -----------------------------------------------------------
MODEL_SAVE_DIR = "tmp"
MODEL_NAME = "tmp.pytorch"
LOAD, SAVE = False, False   # forwarded to start(load=..., save=...)

# --- Label space -------------------------------------------------------------
NUM_CLASSES = 4
TRAIN_IGNORE_INDEX = None   # ignore_index given to the training loss
EVAL_IGNORE_INDEX = 3       # ignore_index given to the validation loss and metrics

# --- Training settings -------------------------------------------------------
EPOCHS = 100
WEIGHT_DECAY = 0.01         # weight_decay of the AdamW optimizer
TARGET_SIZE = 256           # forwarded to start(target_size=...)

# Select the compute device in order of preference: MPS, then CUDA, then CPU.
# The availability checks are evaluated lazily, in that order.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

# Physical batch size and the batch size the run is meant to emulate; here
# they are equal.
target_batch_size = 64
batch_size = 64

# With Augmentation
# One dataset per split, each built from a colour-image folder and a label
# folder; labels are passed through target_remap().
training_data = dataset("datasets/astrain/color", "datasets/astrain/label", target_transform=target_remap())
val_data = dataset("datasets/Val/color", "datasets/Val/label", target_transform=target_remap())
test_data = dataset("datasets/Test/color", "datasets/Test/label", target_transform=target_remap())

# Only the training loader shuffles. The evaluation loaders keep a fixed order
# so that the batch composition (and therefore the reported validation/test
# metrics) is the same from run to run.
# NOTE(review): val/test use diff_size_collate while train uses the default
# collate -- presumably the training images share one size; confirm.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, pin_memory=True)
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False, pin_memory=True, collate_fn=diff_size_collate)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False, pin_memory=True, collate_fn=diff_size_collate)


# Class Weights
# One loss weight per class index (four entries), handed to the loss functions
# as `class_weights`.
#
# Three candidate vectors used to be assigned back to back, so only the last
# one ever took effect. The two superseded candidates are kept as comments:
# class_weight = [0.33265044664009075, 1.669423957743164, 1.9979255956167454, 0.0]
# class_weight = [0.30711034803008996, 1.5412496145750956, 1.8445296893647247, 0.30711034803008996]
class_weight = torch.tensor(
    [0.2046795970925636, 1.0271954434416883, 1.2293222812780409, 1.5388026781877073],
    dtype=torch.float32,  # same dtype the legacy Tensor(...) constructor produced
)
# class_weight = [1, 1, 1, 1]
# class_weight = calculate_class_weights(training_data, 4, None, "dataset")
class_weight = class_weight.to(device)

# Clamp to at least 1 so that a physical batch larger than the target cannot
# produce zero accumulation steps.
accumulation_steps = max(1, target_batch_size // batch_size)

# Model
# Disabled variant that additionally passes a pretrained-encoder checkpoint path:
# pretrained_encoder_path = "/content/drive/MyDrive/autoencoder/256_with_aug_LR1e-3/checkpoint_256_with_aug_TargetSize256.pytorch"
# model = SegmentationAutoencoder(3, 4, pretrained_encoder_path).to(device)
model = SegmentationAutoencoder(3, 4)
model = model.to(device)

# Losses
# Training and validation share the class weights but use different
# ignore indices; only the training loss sets smooth_dice.
train_loss_fn = WeightedDiceCELoss(
    class_weights=class_weight,
    smooth_dice=1,
    ignore_index=TRAIN_IGNORE_INDEX,
)
val_loss_fn = WeightedDiceCELoss(
    class_weights=class_weight,
    ignore_index=EVAL_IGNORE_INDEX,
)

# Disabled alternatives: weighted cross-entropy
# train_loss_fn = nn.CrossEntropyLoss(weight=class_weight)
# val_loss_fn = nn.CrossEntropyLoss(weight=class_weight, ignore_index=EVAL_IGNORE_INDEX)

# Disabled alternatives: unweighted cross-entropy
# train_loss_fn = nn.CrossEntropyLoss()
# val_loss_fn = nn.CrossEntropyLoss(ignore_index=EVAL_IGNORE_INDEX)

# Scheduler (none configured)
scheduler = None

# Optimizer
optimizer = optim.AdamW(model.parameters(), weight_decay=WEIGHT_DECAY)

# Metric History
agg = MetricsHistory(NUM_CLASSES, EVAL_IGNORE_INDEX)

# Training Pipeline
# Everything configured above is handed to start() by keyword.
start(
    # model and optimisation
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    accumulation_steps=accumulation_steps,
    device=device,
    # data
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    target_size=TARGET_SIZE,
    # objectives and metrics
    train_loss_fn=train_loss_fn,
    val_loss_fn=val_loss_fn,
    agg=agg,
    num_classes=NUM_CLASSES,
    ignore_index=EVAL_IGNORE_INDEX,
    # checkpointing
    model_save_dir=MODEL_SAVE_DIR,
    model_save_name=MODEL_NAME,
    load=LOAD,
    save=SAVE,
)

