In [7]:
from dataclasses import dataclass
import torch
from typing import Dict

# Load Data


In [8]:
from dataloader import create_dataloaders
from torchvision.transforms import functional as F
import random

# Constants
PATH = '../data/Splitted CIFAR10.npz'

# Inits
transforms = {
    'random_horizontal_flip': lambda img: F.hflip(img) if random.random() > 0.5 else img,
    'random_vertical_flip': lambda img: F.vflip(img) if random.random() > 0.5 else img,
    'color_jitter': lambda img: F.adjust_brightness(img, brightness_factor=random.uniform(0.8, 1.2)),
    'normalize': lambda img: F.normalize(img, mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261])
}

@dataclass
class DataConfig:
    npz_path_: str = '../data/Splitted CIFAR10.npz'
    lower_ucc: int = 2
    upper_ucc: int = 4
    bag_size: int = 300
    bag_fraction: float = 0.3
    transform: Dict = None

data_config_test = DataConfig()
dataloaders = create_dataloaders(**data_config_test.__dict__)

# Testing the dataloaders
for images, labels in dataloaders['train']:
    print(f'Images batch shape: {images.shape}')
    print(f'Labels batch shape: {labels.shape}')
    print(f'Labels: {labels}')
    print(f'Images: {images}')
    break

Images batch shape: torch.Size([12, 300, 32, 32, 3])
Labels batch shape: torch.Size([12])
Labels: tensor([2, 4, 2, 4, 3, 2, 2, 4, 3, 3, 2, 4])
Images: tensor([[[[[0.6510, 0.6235, 0.5922],
           [0.6118, 0.5843, 0.5529],
           [0.6941, 0.6667, 0.6353],
           ...,
           [0.7020, 0.6706, 0.6588],
           [0.6824, 0.6510, 0.6392],
           [0.6314, 0.6039, 0.5843]],

          [[0.7686, 0.7373, 0.7059],
           [0.6667, 0.6392, 0.6078],
           [0.6745, 0.6471, 0.6157],
           ...,
           [0.7098, 0.6784, 0.6667],
           [0.6863, 0.6549, 0.6431],
           [0.6235, 0.5961, 0.5765]],

          [[0.7843, 0.7569, 0.7255],
           [0.6588, 0.6314, 0.6000],
           [0.6627, 0.6353, 0.6039],
           ...,
           [0.7137, 0.6824, 0.6706],
           [0.6824, 0.6510, 0.6392],
           [0.6196, 0.5961, 0.5765]],

          ...,

          [[0.5725, 0.5490, 0.4980],
           [0.5059, 0.5137, 0.4549],
           [0.4824, 0.5098, 0.4431],
  

# Define the model

In [9]:
from model import UCCModel

@dataclass
class ModelConfig:
    num_bins: int = 10
    sigma : float = 0.1
    dropout_rate: float = 0.1
    num_classes: int = 10
    embedding_size: int = 110
    fc2_size: int = 512

# Init    
model_config_test = ModelConfig()
model = UCCModel(**model_config_test.__dict__)

# Test

# Mock data
batch_size, num_instances, channels, height, width = 2, 5, 3, 32, 32
random_data = torch.randn((batch_size, num_instances, channels, height, width))

# Forward pass through the model
logits, decoded_imgs = model(random_data)

# Outputs
print("Random Data:", random_data.shape)
print("Logits shape:", logits.shape)
print("Decoded images shape:", decoded_imgs.shape)
print(model)


        x: torch.Size([2, 5, 3, 32, 32]),
        x_flat: torch.Size([10, 3, 32, 32]),
        embedding: torch.Size([10, 110]),
        embeddings_reshaped: torch.Size([2, 5, 110]),
        decoded_img: torch.Size([2, 5, 3, 32, 32]),
        feature_distribution: torch.Size([2, 1100]),
        logits: torch.Size([2, 10])
        
Random Data: torch.Size([2, 5, 3, 32, 32])
Logits shape: torch.Size([2, 10])
Decoded images shape: torch.Size([2, 5, 3, 32, 32])
UCCModel(
  (encoder): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 18, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
    )
    (1): Sequential(
      (0): Conv2d(18, 18, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU(inplace=True)
    )
    (2): Sequential(
      (0): Conv2d(18, 18, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
    )
    (3): Sequential(
      (0): Conv2d(18, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    