In [1]:
from dataclasses import dataclass
import torch
from typing import Dict
import numpy as np
import os

# Load Data


In [2]:
from dataloader import create_dataloaders
from torchvision.transforms import functional as TF
import random

# Constants
PATH = '../data/Splitted CIFAR10.npz'

# Inits
transforms = {
    'random_horizontal_flip': lambda img: TF.hflip(img) if random.random() > 0.5 else img,
    'random_vertical_flip': lambda img: TF.vflip(img) if random.random() > 0.5 else img,
    'color_jitter': lambda img: TF.adjust_brightness(img, brightness_factor=random.uniform(0.8, 1.2)),
    'normalize': lambda img: TF.normalize(img, mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261])
}

@dataclass
class DataConfig:
    npz_path_: str = '../data/Splitted CIFAR10.npz'
    lower_ucc: int = 2
    upper_ucc: int = 4
    bag_size: int = 300
    bag_fraction: float = 0.3
    batch_size: int = 32
    transform: Dict = None

data_config_test = DataConfig()
dataloaders = create_dataloaders(**data_config_test.__dict__)

# Testing the dataloaders
for images, labels in dataloaders['train']:
    print(f'Images batch shape: {images.shape}')
    print(f'Labels batch shape: {labels.shape}')
    print(f'Labels: {labels}')
    print(f'Images: {images}')
    break

Images batch shape: torch.Size([32, 300, 32, 32, 3])
Labels batch shape: torch.Size([32])
Labels: tensor([4, 4, 4, 4, 3, 3, 3, 4, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 4, 3,
        4, 4, 2, 4, 4, 2, 4, 4])
Images: tensor([[[[[0.4157, 0.4353, 0.3569],
           [0.2431, 0.2510, 0.1765],
           [0.3333, 0.3412, 0.2706],
           ...,
           [0.3882, 0.3804, 0.2941],
           [0.3412, 0.3412, 0.2510],
           [0.2784, 0.2824, 0.1961]],

          [[0.4235, 0.4431, 0.3647],
           [0.2510, 0.2588, 0.1843],
           [0.3176, 0.3255, 0.2549],
           ...,
           [0.3569, 0.3373, 0.2353],
           [0.2431, 0.2314, 0.1294],
           [0.2549, 0.2392, 0.1451]],

          [[0.4353, 0.4588, 0.3765],
           [0.2549, 0.2627, 0.1922],
           [0.2941, 0.3059, 0.2314],
           ...,
           [0.3412, 0.3216, 0.2118],
           [0.3098, 0.2980, 0.1922],
           [0.3373, 0.3216, 0.2235]],

          ...,

          [[0.0627, 0.0588, 0.0431],
        

# Define the model

In [3]:
from model import UCCModel

@dataclass
class ModelConfig:
    num_bins: int = 10
    sigma : float = 0.1
    dropout_rate: float = 0.1
    num_classes: int = 10
    embedding_size: int = 110
    fc2_size: int = 512

# Init    
model_config_test = ModelConfig()
model = UCCModel(**model_config_test.__dict__)

# Test

# Mock data
batch_size, num_instances, channels, height, width = 2, 5, 3, 32, 32
random_data = torch.randn((batch_size, num_instances, channels, height, width))

# Forward pass through the model
logits, decoded_imgs = model(random_data)

# Outputs
print("Random Data:", random_data.shape)
print("Logits shape:", logits.shape)
print("Decoded images shape:", decoded_imgs.shape)
print(model)

Random Data: torch.Size([2, 5, 3, 32, 32])
Logits shape: torch.Size([2, 10])
Decoded images shape: torch.Size([2, 5, 3, 32, 32])
UCCModel(
  (encoder): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 18, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
    )
    (1): Sequential(
      (0): Conv2d(18, 18, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU(inplace=True)
    )
    (2): Sequential(
      (0): Conv2d(18, 18, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
    )
    (3): Sequential(
      (0): Conv2d(18, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
    )
    (4): Sequential(
      (0): Conv2d(9, 9, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU(inplace=True)
    )
    (5): Flatten(start_dim=1, end_dim=-1)
    (6): Linear(in_features=576, out_features=576, bias=True)
    (7): ReLU()
    (8): Linear(in_features=576, out_features=288,

# Training Functions

In [4]:
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

def combined_loss_fn(logits, decoded_img, labels, original_imgs, ucc_loss_weight=0.5):
    
    ae_loss_weight = 1 - ucc_loss_weight
    
    ucc_loss = F.cross_entropy(logits, labels)
    ae_loss = F.mse_loss(decoded_img, original_imgs)
    combined_loss = (ucc_loss_weight * ucc_loss) + (ae_loss_weight * ae_loss)

    return combined_loss

def eval(model, val_loader, device, combined_loss_fn, ucc_loss_weight=0.5):
    model.eval()
    val_loss_list = []
    val_acc_list = []
    
    with torch.no_grad():
        for batch_samples, batch_labels in val_loader:
            batch_samples = batch_samples.to(device)
            batch_labels = batch_labels.to(device)
            
            # Get model outputs
            ucc_logits, decoded_imgs = model(batch_samples)
            
            # Compute metrics
            ucc_val_loss = combined_loss_fn(ucc_logits, decoded_imgs, batch_labels, batch_samples)

            _, ucc_predicts = torch.max(ucc_logits, dim=1)
            acc = (ucc_predicts == batch_labels).float().mean().item()
            
            val_acc_list.append(acc)
            val_loss_list.append(ucc_val_loss.item())
    
    return np.mean(val_loss_list), np.mean(val_acc_list)

def train(train_loader, val_loader, model, optimizer, combined_loss_fn, model_name, model_dir, device, save_interval, num_epochs, ucc_loss_weight=0.5):
    model.train()
    
    step = 0
    best_eval_acc = 0
    
    for epoch in range(num_epochs):
        print(f"Epoch: {epoch}, Step: {step}")
        for batch_samples, batch_labels in tqdm(train_loader):
            # Load Data
            batch_samples = batch_samples.to(device)
            batch_labels = batch_labels.to(device)

            # DEBUG
            # print(f"Batch Samples: {batch_samples.shape}")
            # print(f"Batch Labels: {batch_labels.shape}")

            # Step 
            optimizer.zero_grad()
            logits, decoded_imgs = model(batch_samples)
            loss = combined_loss_fn(logits, decoded_imgs, batch_labels, batch_samples)
            loss.backward()
        
            optimizer.step()
            step += 1

            if step % save_interval == 0:
                eval_loss, eval_acc = eval(model, val_loader, device, combined_loss_fn, ucc_loss_weight)
                print(f"Epoch: {epoch}, Step: {step}, Eval Loss: {eval_loss}, Eval Acc: {eval_acc}")
                # Early stopping check and model save
                if eval_acc > best_eval_acc:
                    best_eval_acc = eval_acc
                    save_path = os.path.join(model_dir, f"{model_name}_best.pth")
                    save_dict = {
                        "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "eval_loss": eval_loss,
                        "eval_acc": eval_acc,
                        "step": step,
                    }
                    torch.save(save_dict, save_path)
                model.train()
    print("# Training Done\n############################################")
    print(f"Best Eval Acc: {best_eval_acc}")
    print(f"Saved model to {model_dir}")

from dataclasses import dataclass
import os

In [5]:
from train import Trainer

@dataclass
class TrainConfig:
    model: nn.Module
    optimizer: torch.optim.Optimizer
    train_loader: DataLoader
    val_loader: DataLoader
    model_name: str
    total_steps: int = 10_000
    eval_interval: int = 100
    ucc_loss_weight: float = 0.5
    model_dir: str = "./models"
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



# Training
- UCC 1 to 4
- w autoencoder


In [6]:
############################################
# CONFIGS
############################################

DATA_CONFIG_1 = DataConfig(
    npz_path_ = PATH,
    lower_ucc = 1,
    upper_ucc = 4,
    bag_size = 200,
    bag_fraction = 0.2,
    batch_size = 10,
    transform = transforms
)
MODEL_CONFIG_1 = ModelConfig(
    num_bins = 10,
    sigma = 0.1,
    dropout_rate = 0.1,
    num_classes = 10,
    embedding_size = 110,
    fc2_size = 512
)
LEARNING_RATE = 0.0001

############################################
# Instantiate Parts
############################################

# data
dataloaders = create_dataloaders(**DATA_CONFIG_1.__dict__)

# model
model = UCCModel(**MODEL_CONFIG_1.__dict__)

# Optimizer 
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

############################################
# Train
############################################

TRAIN_CONFIG_1 = TrainConfig(
    model=model,
    optimizer=optimizer,
    train_loader=dataloaders['train'],
    val_loader=dataloaders['val'],
    model_name="test_model",
    total_steps=100,
    eval_interval=5,
    ucc_loss_weight=0.5,
    model_dir="../models",
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
)
trainer = Trainer(**TRAIN_CONFIG_1.__dict__)
trainer.train()


###############
# Starting Training...
###############
Step 5 | Train Loss: 2.0008479833602903, Val Loss: 1.4848574250936508, Val Acc: 0.33035715110599995


KeyboardInterrupt: 

In [None]:
"""
UCC

"""
############################################
# CONFIGS
############################################

DATA_CONFIG_1 = DataConfig(
    npz_path_ = PATH,
    lower_ucc = 1,
    upper_ucc = 4,
    bag_size = 200,
    bag_fraction = 0.2,
    batch_size = 10,
    transform = transforms
)
MODEL_CONFIG_1 = ModelConfig(
    num_bins = 10,
    sigma = 0.1,
    dropout_rate = 0.1,
    num_classes = 10,
    embedding_size = 110,
    fc2_size = 512
)
LEARNING_RATE = 0.0001

############################################
# Instantiate Parts
############################################

dataloaders = create_dataloaders(**DATA_CONFIG_1.__dict__)
model = UCCModel(**MODEL_CONFIG_1.__dict__)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

############################################
# Train
############################################
MODEL_NAME = "model1"
TRAIN_CONFIG_1 = TrainConfig(
    model=model,
    optimizer=optimizer,
    train_loader=dataloaders['train'],
    val_loader=dataloaders['val'],
    model_name=MODEL_NAME,
    total_steps=100,
    eval_interval=5,
    ucc_loss_weight=0.5,
    model_dir="../models",
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
)

trainer = Trainer(**TRAIN_CONFIG_1.__dict__)
trainer.train()
