In [None]:
import os
import cv2
import random
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchio as tio
def calculate_alpha(train_dataset, *, minority_divisor=9):
    """
    Calculate the alpha (per-class weight) values for the focal loss function.

    Each weight starts as the inverse of the class's voxel frequency over the
    whole dataset, i.e. ``total_voxels / class_count``. The class-1 weight is
    then divided by ``minority_divisor`` (9 by default, the value this function
    has always used) to damp the weight given to the rare class.

    Ex.
    class 0 has 90 voxels and class 1 has 10 voxels
    inverse class frequencies are 100/90 and 100/10
    with the default divisor the result is [100/90, (100/10)/9] ~= [1.11, 1.11]

    Args:
        train_dataset (torch.utils.data.Dataset): the training dataset;
            ``train_dataset[i][1]`` is used as the label map of sample ``i``.
            Labels are cast to int64 and must be non-negative.
        minority_divisor (float): factor the class-1 weight is divided by.

    Returns:
        torch.Tensor: float tensor ``[w0, w1 / minority_divisor]`` of shape (2,).

    Raises:
        ValueError: if the dataset is empty, or class 0 or class 1 never occurs
            (its inverse frequency would be infinite).
    """
    num_images = len(train_dataset)
    if num_images == 0:
        raise ValueError("train_dataset is empty; cannot compute class weights")

    # Running per-class voxel counts; summing per-sample bincounts avoids
    # concatenating every label map (which was O(n^2) in time and memory).
    class_counts = torch.zeros(2, dtype=torch.int64)
    for i in range(num_images):
        # Fetch each sample exactly once: __getitem__ may be expensive or
        # apply random transforms.
        labels = train_dataset[i][1].to(torch.int64).flatten()
        # minlength=2 guarantees slots for classes 0 and 1 even if absent here;
        # .cpu() keeps the accumulator on one device regardless of label device.
        sample_counts = torch.bincount(labels, minlength=2).cpu()
        if sample_counts.numel() > class_counts.numel():
            # A higher label id appeared; grow the accumulator to fit it.
            grown = torch.zeros_like(sample_counts)
            grown[: class_counts.numel()] = class_counts
            class_counts = grown
        class_counts[: sample_counts.numel()] += sample_counts
        print(f"Processed {i+1}/{num_images} images", end="\r")

    missing = [c for c in (0, 1) if int(class_counts[c]) == 0]
    if missing:
        raise ValueError(
            f"class(es) {missing} never occur in the labels; "
            "inverse-frequency weights would be infinite"
        )

    # Every labeled voxel counts toward the total (including any label ids > 1).
    # This equals len(dataset) * D * H * W when all label maps share one shape.
    total_samples = int(class_counts.sum())

    w1 = 1 / (class_counts[0] / total_samples)
    w2 = 1 / (class_counts[1] / total_samples)
    cls_weights = torch.Tensor([w1, w2 / minority_divisor])
    return cls_weights

: 