In [None]:
import os
import random
from natsort import natsorted
from PIL import Image
from glob import glob
import matplotlib.pyplot as plt
from tqdm import tqdm

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

### Creating dataset

In [None]:
DATASET_PATH = "dataset/asl_dataset/train"
# Only sub-directories are classes; stray files (e.g. .DS_Store) would otherwise
# become bogus classes and inflate NUM_CLASSES.
CLASSES = natsorted(
    [entry for entry in os.listdir(DATASET_PATH) if os.path.isdir(os.path.join(DATASET_PATH, entry))]
)
NUM_CLASSES = len(CLASSES)
print(f"Classes: {CLASSES}")

# Fall back to the CPU when no GPU is present, matching the training helpers below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-channel ImageNet statistics shared by both transforms.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Resize(256) with an int scales the *shorter* image side to 256 px and keeps the
# aspect ratio; it does not force a square output.
train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

class ASLDataset(Dataset):
    """Image-folder dataset for ASL hand signs laid out as ``<dataset_path>/<class>/*.png``.

    Args:
        dataset_path: root directory with one sub-directory per class.
        mode: "train" keeps all but the last ``NUM_TEST_PER_CLASS`` (natsorted)
            images of *each* class, "test" keeps only those last images of each
            class; any other value (e.g. "none") keeps every image.
        single_class: if given, only this class sub-directory is loaded.
        transform: optional transform; defaults to the module-level
            ``train_transform`` when mode == "train", else ``test_transform``.
        generated_path: stored on the instance but not used inside this class.

    ``__getitem__`` returns ``(image, label)`` where ``label`` is the index of the
    class name in the module-level ``CLASSES`` list.
    """

    # Number of images per class held out for the test split.
    NUM_TEST_PER_CLASS = 5

    def __init__(self, dataset_path, mode="train", single_class=None, transform=None, generated_path=False):
        self.dataset_path = dataset_path
        if transform:
            self.transform = transform
        else:
            self.transform = train_transform if mode == "train" else test_transform

        self.image_paths = []
        self.labels = []

        self.mode = mode
        self.single_class = single_class
        self.generated_path = generated_path

        class_pattern = single_class if single_class else "*"
        images = natsorted(glob(os.path.join(dataset_path, class_pattern, "*.png")))

        for img_path in self._split_images(images, mode):
            self.image_paths.append(img_path)
            self.labels.append(CLASSES.index(self._label_from_path(img_path)))

    @staticmethod
    def _label_from_path(img_path):
        """Return the class name, i.e. the name of the image's parent directory (OS-portable)."""
        return os.path.basename(os.path.dirname(img_path))

    @classmethod
    def _split_images(cls, images, mode):
        """Apply the train/test hold-out per class, preserving the input order within each class.

        Slicing the combined multi-class list would put only the last class's
        images into the test split, so the split is done class by class.
        """
        if mode not in ("train", "test"):
            return list(images)
        by_class = {}
        for img_path in images:
            by_class.setdefault(cls._label_from_path(img_path), []).append(img_path)
        n_test = cls.NUM_TEST_PER_CLASS
        selected = []
        for class_images in by_class.values():
            selected.extend(class_images[:-n_test] if mode == "train" else class_images[-n_test:])
        return selected

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert("RGB")
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

train_dataset = ASLDataset(DATASET_PATH, mode="none")

# Per-channel ImageNet statistics used by transforms.Normalize; needed to undo it for display.
display_mean = np.array([0.485, 0.456, 0.406])
display_std = np.array([0.229, 0.224, 0.225])

# Show 10 random samples in a 2x5 grid.
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for ax in axes.flat:
    rand_sample = random.randint(0, len(train_dataset) - 1)
    img, label = train_dataset[rand_sample]
    # Undo Normalize channel-wise, then clip: imshow expects float images in [0, 1].
    img_display = np.clip(img.permute(1, 2, 0).numpy() * display_std + display_mean, 0, 1)
    ax.imshow(img_display)
    ax.set_title(CLASSES[label])
    ax.axis('off')
plt.show()

### Small VAE

In [None]:
from models import TinyVAE, train_vae, test_vae, vae_loss

def train_vae_model(cfg, training_cfg, ALPHABET_CLASS):
    """Train a ``TinyVAE`` on the images of a single ASL class.

    Args:
        cfg: model config passed to ``TinyVAE``; ``cfg['input_res']`` is also the
            resize target for the training images.
        training_cfg: dict with ``batch_size``, ``num_epochs``, ``test_interval``,
            ``lr``, ``kl_reg`` (maximum KL weight) and ``compute_fid``. The optional
            ``kl_warmup`` (default 100) is the number of epochs over which the KL
            weight ramps linearly from 0 to ``kl_reg``.
        ALPHABET_CLASS: name of the class sub-directory under ``DATASET_PATH``.

    Returns:
        The trained model. At epoch 0, epoch 1, every ``test_interval`` epochs and
        the last epoch, ``test_vae`` is run and the state dict is written to
        ``checkpoints/tiny_vae/vae_<class>_epoch_<epoch>.pth``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TinyVAE(cfg).to(device)

    save_dir = 'checkpoints/tiny_vae/'
    os.makedirs(save_dir, exist_ok=True)

    # Light geometric augmentation; there is no Normalize, so ToTensor leaves pixels in [0, 1].
    vae_transform = transforms.Compose([
        transforms.Resize(cfg['input_res']),
        transforms.RandomAffine(degrees=5, translate=(0.05, 0.05)),
        transforms.ToTensor(),
    ])

    train_dataset = ASLDataset(DATASET_PATH, mode="none", single_class=ALPHABET_CLASS, transform=vae_transform)
    # drop_last=True: a class with fewer images than batch_size yields no batches at all.
    train_dataloader = DataLoader(train_dataset, batch_size=training_cfg['batch_size'], shuffle=True, drop_last=True)
    # NOTE(review): training images come from DATASET_PATH, not 'asl_dataset/'; confirm
    # this FID reference folder exists before setting compute_fid=True.
    fid_gt_folder = f'asl_dataset/{ALPHABET_CLASS}'

    num_epochs = training_cfg['num_epochs']
    test_interval = training_cfg['test_interval']

    optimizer = torch.optim.AdamW(model.parameters(), lr=training_cfg['lr'], eps=1e-15)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)

    # KL warm-up: beta grows linearly from 0 to kl_reg over kl_warmup epochs, then stays there.
    kl_max = training_cfg['kl_reg']
    kl_warmup = training_cfg.get('kl_warmup', 100)

    def kl_weight(epoch):
        """KL weight (beta) for ``epoch``; a non-positive warm-up means full weight immediately."""
        if kl_warmup <= 0:
            return kl_max
        return min(kl_max, epoch / kl_warmup * kl_max)

    for epoch in tqdm(range(num_epochs), total=num_epochs, desc="Training ...", position=1):
        train_vae(train_dataloader, model, vae_loss, optimizer, beta=kl_weight(epoch),
                  alphabet_class=ALPHABET_CLASS, epoch=epoch)
        lr_scheduler.step()

        if epoch % test_interval == 0 or epoch == 1 or epoch == num_epochs - 1:
            output_folder = f'training_res/vae/{ALPHABET_CLASS}/samples/epoch_{epoch}'
            test_vae(model, output_folder, num_samples=50, compute_fid=training_cfg['compute_fid'], img_folder=fid_gt_folder)
            # Checkpoint names include the class so training several classes never overwrites them.
            torch.save(model.state_dict(), os.path.join(save_dir, f'vae_{ALPHABET_CLASS}_epoch_{epoch}.pth'))

    return model

In [None]:
# Train an unconditional (single-class) VAE on the hand sign selected here.
ALPHABET_CLASS = '5'

# Model architecture: encoder/decoder channel widths, latent size and input resolution.
cfg = dict(
    img_channels=3,
    latent_dim=32,
    enc_sizes=[16, 32, 64, 128, 256],
    dec_sizes=[256, 128, 64, 32],
    input_res=128,
    is_conditional=False,
    num_classes=NUM_CLASSES,
)

# Optimisation schedule; kl_reg is the maximum KL weight reached after warm-up.
training_cfg = dict(
    batch_size=16,
    num_epochs=5000,
    test_interval=200,
    lr=3e-4,
    kl_reg=0.05,
    compute_fid=False,
)

train_vae_model(cfg, training_cfg, ALPHABET_CLASS)

In [None]:
# Render a looping GIF that linearly interpolates between random VAE latent codes.
num_rows = 1
num_cols = 10
num_phases = 5
seed = 0
transition_frames = 25
static_frames = 5
resolution = cfg['input_res']
output = f'transition_{ALPHABET_CLASS}_vae.gif'

np.random.seed(seed)
output_seq = []
batch_size = num_rows * num_cols
latent_size = cfg['latent_dim']
# One (batch_size, latent_size) batch of Gaussian latents per key frame.
latents = [np.random.randn(batch_size, latent_size) for _ in range(num_phases)]

model = TinyVAE(cfg)
# NOTE(review): no visible code writes this file; presumably test_vae saves model.pth
# into its output folder — confirm.
checkpoint_path = f'training_res/vae/{ALPHABET_CLASS}/samples/epoch_{training_cfg["num_epochs"]-1}/model.pth'
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model = model.to(device)
model.eval()


def to_image_grid(outputs):
    """Tile a (num_rows*num_cols, H, W, C) uint8 batch into one PIL image.

    The grid is resized to (resolution * num_cols) x (resolution * num_rows) pixels.
    """
    outputs = np.reshape(outputs, [num_rows, num_cols, *outputs.shape[1:]])
    outputs = np.concatenate(outputs, axis=1)
    outputs = np.concatenate(outputs, axis=1)
    return Image.fromarray(outputs).resize((resolution * num_cols, resolution * num_rows), Image.LANCZOS)


@torch.no_grad()
def generate(dlatents):
    """Decode a batch of latents into an image grid without building an autograd graph."""
    images = model.decode(dlatents.to(device))
    # Scaled by 255 and clamped; assumes decoder output is roughly in [0, 1] — TODO confirm.
    images = (images.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
    return to_image_grid(images)


for i in range(num_phases):
    # On the first phase i - 1 == -1, so the animation starts from the last key frame
    # and the GIF loops seamlessly.
    dlatents0 = torch.from_numpy(latents[i - 1]).to(device).float()
    dlatents1 = torch.from_numpy(latents[i]).to(device).float()

    for j in range(transition_frames):
        dlatents = (dlatents0 * (transition_frames - j) + dlatents1 * j) / transition_frames
        output_seq.append(generate(dlatents))
    output_seq.extend([generate(dlatents1)] * static_frames)

if not output.endswith('.gif'):
    output += '.gif'
output_seq[0].save(output, save_all=True, append_images=output_seq[1:], optimize=False, duration=50, loop=0)


In [None]:
# Generate and dump 500 samples from the trained VAE (decoder applied to standard-normal latents).
# Relies on `model` loaded in the previous cell.
output_folder = f'generated_images/vae/{ALPHABET_CLASS}/'
os.makedirs(output_folder, exist_ok=True)
num_samples = 500
# Seed torch so the dumped sample set is reproducible across runs.
torch.manual_seed(0)
model.eval()
with torch.no_grad():
    z = torch.randn(num_samples, cfg['latent_dim'], device=device)
    samples = model.decode(z).cpu()
    samples = (samples.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8).numpy()
to_pil = transforms.ToPILImage()
for i, sample in enumerate(samples):
    to_pil(sample).save(os.path.join(output_folder, f"sample_{i}.png"))

### GAN

In [None]:
from models import Generator, Discriminator, train_gan, test_gan

def train_gan_model(cfg, training_cfg, ALPHABET_CLASS):
    """Train an unconditional GAN (``Generator`` + ``Discriminator``) on one ASL class.

    Args:
        cfg: model config shared by ``Generator`` and ``Discriminator``; also supplies
            ``input_res`` (image resize target) and ``latent_dim``.
        training_cfg: dict with ``batch_size``, ``lr``, ``num_epochs``,
            ``test_interval`` and ``compute_fid``.
        ALPHABET_CLASS: name of the class sub-directory under ``DATASET_PATH``.

    Returns:
        The trained generator ``G``.

    At epoch 0, epoch 1, every ``test_interval`` epochs and the final epoch,
    ``test_gan`` is called with ``training_res/gan/<class>/samples/epoch_<n>`` and
    both networks' state dicts are written to ``checkpoints/gan/{G,D}_epoch_<n>.pth``.
    """
    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # G is built before D so weight initialisation draws from the RNG in a fixed order.
    G = Generator(cfg).to(run_device)
    D = Discriminator(cfg).to(run_device)

    ckpt_dir = 'checkpoints/gan/'
    os.makedirs(ckpt_dir, exist_ok=True)

    # Plain resize + [0, 1] tensor conversion: no augmentation, no normalisation.
    resize_to_tensor = transforms.Compose([transforms.Resize(cfg['input_res']), transforms.ToTensor()])
    class_images = ASLDataset(DATASET_PATH, mode="none", single_class=ALPHABET_CLASS, transform=resize_to_tensor)
    loader = DataLoader(class_images, batch_size=training_cfg['batch_size'], shuffle=True, drop_last=True)

    real_img_folder = f'asl_dataset/{ALPHABET_CLASS}'

    # Same Adam settings (beta1 = 0, beta2 = 0.99) for both players.
    adam_kwargs = {'lr': training_cfg['lr'], 'betas': (0.0, 0.99)}
    g_optimizer = torch.optim.Adam(G.parameters(), **adam_kwargs)
    d_optimizer = torch.optim.Adam(D.parameters(), **adam_kwargs)

    total_epochs = training_cfg['num_epochs']
    eval_every = training_cfg['test_interval']
    z_dim = cfg['latent_dim']
    final_epoch = total_epochs - 1

    for epoch in tqdm(range(total_epochs), total=total_epochs, desc="Training GAN ...", position=1):
        train_gan(
            train_loader=loader,
            G=G,
            D=D,
            g_optimizer=g_optimizer,
            d_optimizer=d_optimizer,
            epoch=epoch,
            latent_dim=z_dim,
            use_diffaugment=True,
        )

        # Only evaluate/checkpoint on interval epochs, epoch 1 and the last epoch.
        if epoch % eval_every != 0 and epoch not in (1, final_epoch):
            continue

        test_gan(
            G,
            save_img_folder=f'training_res/gan/{ALPHABET_CLASS}/samples/epoch_{epoch}',
            latent_dim=z_dim,
            num_samples=50,
            compute_fid=training_cfg['compute_fid'],
            img_folder=real_img_folder,
        )

        for tag, net in (('G', G), ('D', D)):
            torch.save(net.state_dict(), os.path.join(ckpt_dir, f'{tag}_epoch_{epoch}.pth'))

    return G

In [None]:
# Train an unconditional GAN on the single hand-sign class selected here.
ALPHABET_CLASS = '5'

cfg = {
    'img_channels': 3,
    'latent_dim': 16,
    'gen_sizes': [512, 256, 128, 64],
    'disc_sizes': [64, 128, 256, 512],
    'input_res': 128,
    'is_conditional': False,
    'num_classes': NUM_CLASSES,
}

training_cfg = {
    'batch_size': 8,  # small batches for single-class GAN training
    'num_epochs': 4000,
    'test_interval': 400,  # evaluate and checkpoint G/D every 400 epochs
    'lr': 1e-4,
    'compute_fid': False,
}

G = train_gan_model(cfg, training_cfg, ALPHABET_CLASS)

In [None]:
# Render a looping GIF that linearly interpolates between random GAN latent codes.
num_rows = 1
num_cols = 10
num_phases = 5
seed = 0
transition_frames = 25
static_frames = 5
resolution = cfg['input_res']
output = f'transition_{ALPHABET_CLASS}.gif'

np.random.seed(seed)
output_seq = []
batch_size = num_rows * num_cols
latent_size = cfg['latent_dim']
# One (batch_size, latent_size) batch of Gaussian latents per key frame.
latents = [np.random.randn(batch_size, latent_size) for _ in range(num_phases)]

G = Generator(cfg).to(device)
# train_gan_model always checkpoints its final epoch (epoch == num_epochs - 1) as
# checkpoints/gan/G_epoch_<epoch>.pth. That directory is not per class, so this is
# the generator of the most recent train_gan_model run.
gen_ckpt = os.path.join('checkpoints/gan/', f'G_epoch_{training_cfg["num_epochs"] - 1}.pth')
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
G.load_state_dict(torch.load(gen_ckpt, map_location=device))
G.eval()


def to_image_grid(outputs):
    """Tile a (num_rows*num_cols, H, W, C) uint8 batch into one PIL image.

    The grid is resized to (resolution * num_cols) x (resolution * num_rows) pixels.
    """
    outputs = np.reshape(outputs, [num_rows, num_cols, *outputs.shape[1:]])
    outputs = np.concatenate(outputs, axis=1)
    outputs = np.concatenate(outputs, axis=1)
    return Image.fromarray(outputs).resize((resolution * num_cols, resolution * num_rows), Image.LANCZOS)


@torch.no_grad()
def generate(dlatents, c=None):
    """Run G on a batch of latents and return an image grid; ``c`` is accepted but unused."""
    images = G(dlatents.to(device))
    # Scaled by 255 and clamped; assumes generator output is roughly in [0, 1] — TODO confirm.
    images = (images.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
    return to_image_grid(images)


for i in range(num_phases):
    # On the first phase i - 1 == -1, so the animation starts from the last key frame
    # and the GIF loops seamlessly.
    dlatents0 = torch.from_numpy(latents[i - 1]).to(device).float()
    dlatents1 = torch.from_numpy(latents[i]).to(device).float()

    for j in range(transition_frames):
        dlatents = (dlatents0 * (transition_frames - j) + dlatents1 * j) / transition_frames
        output_seq.append(generate(dlatents))
    output_seq.extend([generate(dlatents1)] * static_frames)

if not output.endswith('.gif'):
    output += '.gif'
output_seq[0].save(output, save_all=True, append_images=output_seq[1:], optimize=False, duration=50, loop=0)

In [None]:
# Generate and dump 500 samples from the trained GAN generator (standard-normal latents).
# Relies on `G` loaded in the previous cell.
output_folder = f'generated_images/ucgan/{ALPHABET_CLASS}/'
os.makedirs(output_folder, exist_ok=True)
num_samples = 500
# Seed torch so the dumped sample set is reproducible across runs.
torch.manual_seed(0)
G.eval()
with torch.no_grad():
    z = torch.randn(num_samples, cfg['latent_dim'], device=device)
    samples = G(z).cpu()
    samples = (samples.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8).numpy()
to_pil = transforms.ToPILImage()
for i, sample in enumerate(samples):
    to_pil(sample).save(os.path.join(output_folder, f"sample_{i}.png"))

### Downstream classification tasks

In [None]:
# NOTE: this rebinds DATASET_PATH to the dataset root (with train/ and test/ inside).
# train_vae_model / train_gan_model read DATASET_PATH as the *train* folder, so re-run
# the first dataset cell before calling them again.
DATASET_PATH = "dataset/asl_dataset"
train_root = os.path.join(DATASET_PATH, 'train')
# Only sub-directories are classes; stray files would otherwise become bogus classes.
CLASSES = natsorted(
    [entry for entry in os.listdir(train_root) if os.path.isdir(os.path.join(train_root, entry))]
)
NUM_CLASSES = len(CLASSES)
print(f"Classes: {CLASSES}")


GAN_GEN_DATASET_PATH = "generated_images/ucgan/"
VAE_GEN_DATASET_PATH = "generated_images/vae/"

# Fall back to the CPU when no GPU is present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class HybridASLDataset(Dataset):
    """ASL dataset combining real images with optionally generated ones.

    Real images are read from ``<dataset_path>/<mode>/<class>/*.png``. In train mode,
    generated images are appended from ``<gen root>/<class>/*.png``, where the gen
    root is ``GAN_GEN_DATASET_PATH`` (gen_mode="ucgan") or ``VAE_GEN_DATASET_PATH``
    (gen_mode="vae").

    Args:
        dataset_path: dataset root containing ``train/`` and ``test/``.
        mode: "train" or "test".
        transform: optional transform; defaults to the module-level
            ``train_transform`` (train) or ``test_transform`` (test).
        gen_mode: "none", "ucgan" or "vae".
        no_real_data: skip real images (train mode only).

    ``__getitem__`` returns ``(image, label, is_generated)``: ``label`` indexes the
    module-level ``CLASSES`` list, ``is_generated`` is 1 for generated samples and 0
    for real ones.
    """

    def __init__(self, dataset_path, mode="train", transform=None, gen_mode="none", no_real_data=False):

        self.dataset_path = dataset_path

        print(f"Initializing HybridASLDataset with gen_mode={gen_mode}, no_real_data={no_real_data}, mode={mode}")

        assert gen_mode in ["none", "ucgan", "vae"], "gen_mode must be one of 'none', 'ucgan', or 'vae'"
        assert mode in ["train", "test"], "mode must be either 'train' or 'test'"
        assert not (no_real_data and mode!="train"), "no_real_data can only be True in train mode"

        self.gen_dataset_path = None
        if gen_mode == "ucgan":
            self.gen_dataset_path = GAN_GEN_DATASET_PATH
        elif gen_mode == "vae":
            self.gen_dataset_path = VAE_GEN_DATASET_PATH

        if transform:
            self.transform = transform
        else:
            self.transform = train_transform if mode == "train" else test_transform

        self.image_paths = []
        self.labels = []
        self.is_generated = []

        self.mode = mode
        self.gen_mode = gen_mode
        self.no_real_data = no_real_data

        if not self.no_real_data:
            images = natsorted(glob(os.path.join(dataset_path, mode, "*", "*.png")))
            for img_path in images:
                self._add(img_path, self._label_from_path(img_path), is_generated=0)

        # Generated images are only ever used for training.
        if self.gen_dataset_path and self.mode == "train":
            for class_name in CLASSES:
                # natsorted so the index -> file mapping is deterministic (raw glob
                # order depends on the filesystem).
                gen_images = natsorted(glob(os.path.join(self.gen_dataset_path, class_name, "*.png")))
                for img_path in gen_images:
                    self._add(img_path, class_name, is_generated=1)

    @staticmethod
    def _label_from_path(img_path):
        """Return the class name, i.e. the name of the image's parent directory (OS-portable)."""
        return os.path.basename(os.path.dirname(img_path))

    def _add(self, img_path, class_name, is_generated):
        """Append one sample; ``class_name`` must be an entry of ``CLASSES``."""
        self.image_paths.append(img_path)
        self.labels.append(CLASSES.index(class_name))
        self.is_generated.append(is_generated)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert("RGB")
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        is_generated = self.is_generated[idx]

        return image, label, is_generated

train_dataset = HybridASLDataset(DATASET_PATH, mode="train", gen_mode="none")
print(DATASET_PATH)

# Per-channel ImageNet statistics used by transforms.Normalize; needed to undo it for display.
display_mean = np.array([0.485, 0.456, 0.406])
display_std = np.array([0.229, 0.224, 0.225])

# Show 10 random samples in a 2x5 grid.
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for ax in axes.flat:
    rand_sample = random.randint(0, len(train_dataset) - 1)
    img, label, _ = train_dataset[rand_sample]
    # Undo Normalize channel-wise, then clip: imshow expects float images in [0, 1].
    img_display = np.clip(img.permute(1, 2, 0).numpy() * display_std + display_mean, 0, 1)
    ax.imshow(img_display)
    ax.set_title(CLASSES[label])
    ax.axis('off')
plt.show()

In [None]:
from models import ResNetModel, SimpleCNN, train_clas, test_clas, save_checkpoint, load_model, plot_results


def train_classification_model(model_type="tinycnn", data_regime="mixed", gen_model="ucgan", training_cfg=None):
    """Train a classifier on real, generated, or mixed ASL images and track accuracy.

    Args:
        model_type: "tinycnn" (``SimpleCNN``) or "resnet" (``ResNetModel``).
        data_regime: "real_only", "gen_only" (generated images only) or "mixed"
            (real + generated training images). Evaluation always uses real test images.
        gen_model: generated set to use ("ucgan" or "vae"); unused for "real_only"
            except in the checkpoint directory name.
        training_cfg: required dict with ``batch_size``, ``num_epochs``,
            ``test_interval`` and ``lr``; the optional ``fake_weight_max``
            (default 0.1) is forwarded to ``train_clas``.

    Returns:
        ``(model, epochs_list, train_losses, train_accuracy_list, test_losses,
        test_accuracy_list)``; each metric list has one entry per evaluation epoch.
        The best-test-accuracy model is saved as ``best_model.pth`` in the checkpoint dir.

    Raises:
        ValueError: if ``training_cfg`` is not provided.
    """
    if training_cfg is None:
        raise ValueError("training_cfg is required (batch_size, num_epochs, test_interval, lr)")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    assert model_type in ["tinycnn", "resnet"], "model_type must be either 'tinycnn' or 'resnet'"
    assert data_regime in ["real_only", "gen_only", "mixed"], "data_regime must be one of 'real_only', 'gen_only', or 'mixed'"

    # data_regime -> HybridASLDataset keyword arguments.
    regime_kwargs = {
        "real_only": {'no_real_data': False, 'gen_mode': "none"},
        "gen_only": {'no_real_data': True, 'gen_mode': gen_model},
        "mixed": {'no_real_data': False, 'gen_mode': gen_model},
    }
    data_kwargs = regime_kwargs[data_regime]

    # Resize(128) scales the shorter side to 128 px; no Normalize, so pixels stay in [0, 1].
    train_transform = transforms.Compose([
        transforms.Resize(128),
        transforms.ToTensor(),
    ])

    test_transform = transforms.Compose([
        transforms.Resize(128),
        transforms.ToTensor(),
    ])

    train_dataset = HybridASLDataset(
        DATASET_PATH,
        mode="train",
        transform=train_transform,
        **data_kwargs
    )

    eval_dataset = HybridASLDataset(
        DATASET_PATH,
        mode="test",
        transform=test_transform,
        gen_mode="none"
    )

    model_arch = ResNetModel if model_type == "resnet" else SimpleCNN
    model = model_arch(NUM_CLASSES).to(device)

    save_dir = f'checkpoints/classification/{model_type}_{data_regime}_{gen_model}/'
    os.makedirs(save_dir, exist_ok=True)

    train_dataloader = DataLoader(train_dataset, batch_size=training_cfg['batch_size'], shuffle=True)
    eval_dataloader = DataLoader(eval_dataset, batch_size=16, shuffle=False)

    num_epochs = training_cfg['num_epochs']
    test_interval = training_cfg['test_interval']
    fake_weight_max = training_cfg.get('fake_weight_max', 0.1)

    optimizer = torch.optim.AdamW(model.parameters(), lr=training_cfg['lr'], eps=1e-15)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)

    # Per-sample losses, so train_clas can down-weight generated samples via fake_weight_max.
    # `reduce=` is a deprecated *boolean* flag (any non-empty string means "reduce"); the
    # keyword that returns unreduced losses is `reduction='none'`.
    # NOTE(review): confirm train_clas/test_clas reduce the per-sample loss themselves.
    criterion = torch.nn.CrossEntropyLoss(reduction='none')

    epochs_list = []
    train_losses = []
    train_accuracy_list = []
    test_losses = []
    test_accuracy_list = []

    for epoch in tqdm(range(num_epochs), total=num_epochs, desc="Training ...", position=1):
        train_clas(train_dataloader, model, criterion, optimizer, epoch, fake_weight_max=fake_weight_max)
        lr_scheduler.step()

        if epoch % test_interval == 0 or epoch == 1 or epoch == num_epochs - 1:
            print('Evaluating Network')
            epochs_list.append(epoch)

            train_accuracy, train_loss = test_clas(train_dataloader, model, criterion)
            train_losses.append(train_loss)
            train_accuracy_list.append(train_accuracy)
            print(f'Training accuracy on epoch {str(epoch)} is {str(train_accuracy)}')

            test_accuracy, test_loss = test_clas(eval_dataloader, model, criterion)
            test_losses.append(test_loss)
            test_accuracy_list.append(test_accuracy)
            print(f'Test (val) accuracy on epoch {str(epoch)} is {str(test_accuracy)}')

            # Keep the checkpoint with the best (ties: latest) test accuracy so far.
            if test_accuracy >= max(test_accuracy_list):
                print("Saving Model")
                save_checkpoint(save_dir, model, save_name='best_model.pth')

    return model, epochs_list, train_losses, train_accuracy_list, test_losses, test_accuracy_list

In [None]:
# Baseline: small CNN trained on real images only.
cfg = {
    'model_type': 'tinycnn',
    'data_regime': 'real_only',
    'gen_model': 'none',
}

training_cfg = {
    'batch_size': 128,
    'num_epochs': 100,
    'test_interval': 20,
    'lr': 5e-5,
}

model, epochs_list, train_losses, train_accuracy_list, test_losses, test_accuracy_list = \
  train_classification_model(training_cfg=training_cfg, **cfg)

# float() accepts single-element tensors as well as plain Python/NumPy numbers,
# so this works whatever scalar type test_clas returns.
train_accuracy_list_items = [float(x) for x in train_accuracy_list]
test_accuracy_list_items = [float(x) for x in test_accuracy_list]
plot_results(epochs_list, train_accuracy_list_items, test_accuracy_list_items, "Accuracy")