In [1]:
from datetime import datetime

import pandas as pd
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import LearningRateMonitor, EarlyStopping
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

from era_data import TabletPeriodDataset, get_IDS
from VAE_model_tablets_class import VAE

# Hyperparameters

In [2]:
# --- Data ------------------------------------------------------------------
IMG_DIR = 'output/images_preprocessed'

# --- Optimisation ----------------------------------------------------------
BATCH_SIZE = 8
EPOCHS = 30
LR = 0.0001
BETA=1  # passed to the VAE as `beta`

# --- Run naming ------------------------------------------------------------
SUFFIX = '-VAE'
RUN_NAME_SUFFIX = '-masked_w_classification_loss-equalpartsloss' # ''
# Month name + day only (e.g. "April20"); the year is not part of the run name.
DATE = datetime.now().strftime("%B%d")

# All sample IDs found for IMG_DIR; only the count is used in the run name.
IDS = get_IDS(IMG_DIR=IMG_DIR)
print(len(IDS))

# Name of this run; used as the TensorBoard logger version.
VERSION_NAME = (
    f'period_clf_bs{BATCH_SIZE}_lr{LR}_beta_{BETA}_epochs_{EPOCHS}'
    f'{SUFFIX}-{len(IDS)}_samples{RUN_NAME_SUFFIX}-{DATE}_2'
)
print(VERSION_NAME)

# Version name of an earlier ResNet run; its saved train/val/test ID lists are
# re-read so this model sees the same splits as previous models.
RESNET_VERNAME = 'period_clf_bs16_lr1e-05_20epochs-resnet50-94936_samples_preprocessed-masked_April16-80-10-10_train_test_val-2'

94936
period_clf_bs8_lr0.0001_beta_1_epochs_30-VAE-94936_samples-masked_w_classification_loss-equalpartsloss-April20_2


# Load data

In [3]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [4]:
# Read the ID lists saved by the earlier ResNet run (one header-less CSV per
# split) and keep the first column as strings.
train_ids, val_ids, test_ids = (
    pd.read_csv(f'output/clf_ids/period-{split}-{RESNET_VERNAME}.csv', header=None)[0].astype(str)
    for split in ('train', 'val', 'test')
)

In [5]:
# One masked dataset per split, all reading images from the same directory.
ds_train, ds_val, ds_test = (
    TabletPeriodDataset(IDS=split_ids, IMG_DIR=IMG_DIR, mask=True)
    for split_ids in (train_ids, val_ids, test_ids)
)

Filtering 94936 IDS down to provided 75948...
Filtering 94936 IDS down to provided 9494...
Filtering 94936 IDS down to provided 9494...


In [6]:
def collate_fn(batch):
    """Collate dataset samples into an ``(images, labels)`` pair of tensors.

    Each sample is indexed positionally: ``sample[1]`` must be a NumPy array
    (the image) and ``sample[2]`` its class label. ``sample[0]`` is ignored
    here (presumably the sample ID -- confirm against the dataset).

    Returns:
        A tuple ``(data, labels)`` where ``data`` stacks every image with an
        extra leading channel axis of size 1, and ``labels`` is a 1-D tensor
        holding one label per sample.
    """
    images = []
    targets = []
    for sample in batch:
        # Indexing with None adds the singleton channel dimension.
        images.append(torch.from_numpy(sample[1])[None])
        targets.append(sample[2])

    return torch.stack(images), torch.tensor(targets)

In [7]:
# Settings shared by every split; only the training loader is shuffled.
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=4,
    collate_fn=collate_fn,
    pin_memory=True,
)
dl_train = DataLoader(ds_train, shuffle=True, **_loader_kwargs)
dl_val = DataLoader(ds_val, shuffle=False, **_loader_kwargs)
dl_test = DataLoader(ds_test, shuffle=False, **_loader_kwargs)

In [8]:
# Sanity check: pull a single batch from the training loader and print the
# shapes of the image tensor and the label tensor, then stop iterating.
for batch in dl_train:
    x, y = batch  # (data, labels) as returned by collate_fn
    print(x.shape,y.shape)
    break

torch.Size([8, 1, 512, 512]) torch.Size([8])


# Create Model

In [9]:
# TensorBoard logs are written under ./lightning_logs/<VERSION_NAME>.
logger = pl.loggers.TensorBoardLogger(save_dir='.', name='lightning_logs', version=VERSION_NAME)

In [10]:
# Log the learning rate at every optimisation step (rather than once per epoch).
lr_monitor = LearningRateMonitor(logging_interval='step')

In [11]:
# Stop training once `val_total_loss` has not decreased by at least
# `min_delta` for `patience` consecutive checks.
early_stop_callback = EarlyStopping(
    monitor='val_total_loss',
    min_delta=0.000001,
    patience=5,
    verbose=True,  # boolean flag; the previous value 10 was merely truthy
    mode='min',
    # Run the stopping check at the end of each training epoch instead of
    # after each validation run.
    # NOTE(review): the Trainer validates several times per epoch
    # (val_check_interval=0.2), so patience is presumably meant to be counted
    # in epochs here -- confirm this is intended.
    check_on_train_epoch_end=True,
)

In [12]:
# Number of target classes, taken from the dataset's PERIOD_INDICES mapping.
num_classes = len(TabletPeriodDataset.PERIOD_INDICES)
num_classes

22

In [13]:
def compute_class_weights(dataloader, num_classes, epsilon=1e-6):
    """Compute inverse-frequency class weights from the labels a loader yields.

    Args:
        dataloader: Iterable yielding ``(data, labels)`` pairs, where
            ``labels`` is an integer tensor of class indices in
            ``[0, num_classes)``. Only the labels are used.
        num_classes: Total number of classes; the length of the result.
        epsilon: Small constant added to every class proportion so that a
            class with no samples does not cause a division by zero.

    Returns:
        A floating-point tensor of shape ``(num_classes,)`` whose entries sum
        to ``num_classes``.

    Raises:
        ValueError: If the dataloader yields no labels at all (the proportions
            would otherwise be 0/0 and every weight NaN).

    Note:
        A class that never occurs gets a raw weight of about ``1 / epsilon``,
        which dominates the normalisation and pushes the weights of all
        observed classes towards zero. Make sure every class is represented
        in the data passed in.
    """
    # Accumulate in int64 so the counts stay exact regardless of dataset size.
    class_counts = torch.zeros(num_classes, dtype=torch.long)
    for _, labels in tqdm(dataloader):
        # bincount rejects negative labels instead of silently wrapping around.
        class_counts += torch.bincount(labels.reshape(-1), minlength=num_classes)

    # Normalise by the number of labels actually seen; this matches
    # len(dataloader.dataset) for an ordinary loader and also works for
    # datasets that do not define __len__.
    total = int(class_counts.sum())
    if total == 0:
        raise ValueError("dataloader yielded no labels; cannot compute class weights")

    # Compute class proportions
    class_proportions = class_counts.to(torch.get_default_dtype()) / total

    # Invert the proportions to get class weights; epsilon avoids division by zero
    class_weights = 1.0 / (class_proportions + epsilon)

    # Normalize the weights so they sum to num_classes
    class_weights = class_weights / class_weights.sum() * num_classes

    return class_weights

In [14]:
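# One-off cache generation: iterates over the whole training loader to compute
# the class weights and stores them on disk. Run manually to refresh the cached
# file, e.g. after the training split changes.
# class_weights = compute_class_weights(dl_train, num_classes)
# torch.save(class_weights, "data/class_weights_period.pt")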

In [15]:
# Load the cached class weights. If the cache file does not exist yet (e.g. on
# a fresh checkout), compute the weights from the training loader once and
# store them so later runs can reuse them. Note that the fallback iterates over
# the entire training set.
CLASS_WEIGHTS_PATH = "data/class_weights_period.pt"
try:
    class_weights = torch.load(CLASS_WEIGHTS_PATH)
except FileNotFoundError:
    class_weights = compute_class_weights(dl_train, num_classes)
    torch.save(class_weights, CLASS_WEIGHTS_PATH)

In [16]:
# VAE over single-channel images with an auxiliary, class-weighted
# classification loss.
model = VAE(
    image_channels=1,
    z_dim=12,  # z_dim = size of embeddings bottleneck
    lr=LR,
    beta=BETA,
    use_classification_loss=True,
    num_classes=num_classes,
    loss_type="weighted",
    class_weights=class_weights,
    device=device,
)

  self.class_weights = torch.tensor(class_weights).to(device)


# Train Model

In [17]:
# Configure the Lightning trainer.
trainer = pl.Trainer(
    # Follow the `device` detected earlier so the notebook also runs on a
    # machine without CUDA instead of failing on a hard-coded 'gpu'.
    accelerator='gpu' if device == 'cuda' else 'cpu',
    callbacks=[lr_monitor, early_stop_callback],
    max_epochs=EPOCHS,
    devices='auto',
    val_check_interval=0.2,  # run validation after every 20% of a training epoch
    logger=logger,
)

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
# Train on the training loader, validating on the validation loader.
trainer.fit(model, dl_train, dl_val) 

  rank_zero_warn(
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type             | Params
-------------------------------------------------
0 | encoder     | Sequential       | 2.7 M 
1 | fc1         | Linear           | 786 K 
2 | fc2         | Linear           | 786 K 
3 | fc3         | Linear           | 851 K 
4 | decoder     | Sequential       | 2.7 M 
5 | criterion   | CrossEntropyLoss | 0     
6 | fc_classify | Sequential       | 286   
-------------------------------------------------
7.9 M     Trainable params
0         Non-trainable params
7.9 M     Total params
31.430    Total estimated model params siz

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]