# Unzip the BigGAN Dataset

In [None]:
# Mount Google Drive so the project folders under /content/drive are reachable.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

In [None]:
# Work from the BigGAN folder on Drive; the unzip cell below reads unsplit.zip from the current directory.
%cd /content/drive/My Drive/OMSCS/OMSCS_DL_Project/GenImage/BigGAN

In [None]:
! ls

In [None]:
! pwd

In [None]:
! ls /content/sample_data

In [None]:
! unzip unsplit.zip -d /content/sample_data/BigGAN/

In [None]:
# %cp -r /content/sample_data/BigGAN/ /content/drive/MyDrive/OMSCS_DL_Project/GenImage/BigGAN/

In [None]:
# ! zip -F imagenet_ai_0508_adm.zip --out unsplit.zip

## Check the Number of Files in Each Split

In [None]:
! ls /content/sample_data/BigGAN/imagenet_ai_0419_biggan/train/ai | wc -l

In [None]:
! ls /content/sample_data/BigGAN/imagenet_ai_0419_biggan/train/nature/ | wc -l

In [None]:
! ls /content/sample_data/BigGAN/imagenet_ai_0419_biggan/val/ai | wc -l

In [None]:
! ls /content/sample_data/BigGAN/imagenet_ai_0419_biggan/val/nature/ | wc -l

## Show a Sample Image

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Sanity check: render one extracted train/nature image.
fig, ax = plt.subplots()
img = mpimg.imread('/content/sample_data/BigGAN/imagenet_ai_0419_biggan/train/nature/n01582220_4551.JPEG')
imgplot = ax.imshow(img)
ax.axis("off")
plt.show()

# Prepare Dataloader

In [None]:
import os
from os import listdir
from os.path import isfile, join
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torchvision.io import read_image

import torch

from skimage import io, transform

import matplotlib.pyplot as plt

%matplotlib inline

In [None]:
dataset_type = 'val'
model_type = 'nature'
root_dir = '/content/sample_data/BigGAN/imagenet_ai_0419_biggan'

# os.listdir returns entries in arbitrary, filesystem-dependent order; sorting
# makes positional picks such as image_name[100] in the next cell reproducible.
image_name = sorted(os.listdir(os.path.join(root_dir, dataset_type, model_type)))


In [None]:
# Display the file at index 100 of the listing from the previous cell.
# The path gets its own name: reassigning `image_name` (the file list) to a
# path string made a re-run of this cell index a single character instead.
image_path = os.path.join(root_dir, dataset_type, model_type, image_name[100])
fig, ax = plt.subplots()
img = mpimg.imread(image_path)
imgplot = ax.imshow(img)
ax.axis("off")
plt.show()

In [None]:
# example: plots do not have three layers
# from IPython.display import Image
# Image(filename='/content/sample_data/BigGAN/imagenet_ai_0419_biggan/train/nature/n02092002_1032.JPEG')

In [None]:
# Switch to the project's Code folder; presumably this is what makes `data_prep_util` (imported below) importable.
%cd /content/drive/My Drive/OMSCS/OMSCS_DL_Project/Deep_Learning_Final_Project/Code

In [None]:
! ls

In [None]:
from data_prep_util import GenImageDataset, Rescale, HighPassFilter, State, CheckPoint

# Define Hyperparameters

In [None]:
# Hyperparameters and configurations
class Config:
    """Run-wide hyperparameters and switches, read as class attributes (``Config.x``)."""

    # --- data loading ---
    batch_size: int = 32
    num_workers: int = 8

    # --- optimisation ---
    num_epochs: int = 4            # number of epochs during training
    learning_rate: float = 3e-5    # learning rate for learnable parameters

    # --- MLP classifier head (two hidden layers) ---
    hidden_dim1: int = 500
    hidden_dim2: int = 500
    dropout: float = 0.1           # dropout inside the head

    # --- preprocessing ---
    use_filter: bool = False       # False disables the high pass filter
    alpha_value: float = 0.5       # high pass filter strength, between 0 and 1

    # --- backbone ---
    pretrained: bool = False       # True loads pretrained weights

# Define Transforms

In [None]:
def get_transformations(use_high_pass_filter=False, alpha_value=0.5, rescale_size=256):
    """Compose the per-sample preprocessing passed to GenImageDataset.

    Always starts with Rescale(rescale_size); HighPassFilter(alpha=alpha_value)
    is appended only when ``use_high_pass_filter`` is true.
    """
    rescale = Rescale(rescale_size)
    if not use_high_pass_filter:
        return transforms.Compose([rescale])
    return transforms.Compose([rescale, HighPassFilter(alpha=alpha_value)])


In [None]:

# Training split: one GenImageDataset per class folder, concatenated nature first.
dataset_type = 'train'
root_dir = '/content/sample_data/BigGAN/imagenet_ai_0419_biggan'

BigGen_train_nature, BigGen_train_ai = (
    GenImageDataset(root_dir, dataset_type, folder,
                    transform=get_transformations(Config.use_filter, Config.alpha_value))
    for folder in ('nature', 'ai')
)

BigGen_train = torch.utils.data.ConcatDataset([BigGen_train_nature, BigGen_train_ai])

In [None]:
# Validation split, built the same way (root_dir comes from the cell above).
dataset_type = 'val'

BigGen_val_nature, BigGen_val_ai = (
    GenImageDataset(root_dir, dataset_type, folder,
                    transform=get_transformations(Config.use_filter, Config.alpha_value))
    for folder in ('nature', 'ai')
)

BigGen_val = torch.utils.data.ConcatDataset([BigGen_val_nature, BigGen_val_ai])

In [None]:
# Total training samples (nature + ai).
len(BigGen_train)

In [None]:
# Total validation samples (nature + ai).
len(BigGen_val)

In [None]:
# Inspect one sample; the training loop reads its 'image', 'model_type', 'dataset_type' and 'image_name' entries.
BigGen_train[100]

In [None]:
# Shape of the transformed image tensor (the display cell below treats dim 0 as channels).
BigGen_train[100]['image'].shape

In [None]:
# Render a transformed sample; permute moves the leading (presumably channel) axis last for imshow.
fig, ax = plt.subplots()
img = BigGen_train[100]['image']
imgplot = ax.imshow(img.permute(1, 2, 0))
ax.axis("off")
plt.show()

In [None]:
# First sample via the iteration protocol (same element as BigGen_train[0]).
next(iter(BigGen_train))

# Create Dataloader

In [None]:
# from torch.utils.data import DataLoader

# train_dataloader = DataLoader(BigGen_train, batch_size=64, shuffle=True)
# val_dataloader = DataLoader(BigGen_val, batch_size=64, shuffle=True)

Background: [DataLoader resets dataset state](https://discuss.pytorch.org/t/dataloader-resets-dataset-state/27960)

Background: [PyTorch dataloaders in memory](https://discuss.pytorch.org/t/pytorch-dataloaders-in-memory/118471)

# Try a Swin Transformer

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image

In [None]:
import torchvision.models as models
import torch.nn as nn

# Swin-T backbone whose classification head is replaced by an MLP over `classes`.
#
# torchvision selects pretrained weights with `weights=` (the API that shipped
# together with swin_t in 0.13); the legacy `pretrained=` flag is deprecated.
# weights=None gives a randomly initialised backbone, as pretrained=False did.
weights = models.Swin_T_Weights.IMAGENET1K_V1 if Config.pretrained else None
model = models.swin_t(weights=weights, progress=True)

# NOTE(review): ImageNet weights expect inputs normalised like
# `weights.transforms()`; get_transformations() only rescales (plus the optional
# high-pass filter). Confirm whether Rescale/GenImageDataset normalises before
# enabling Config.pretrained.

# For transfer learning, uncomment to freeze the pretrained weights. The new
# head is created afterwards, so it stays trainable either way.
# for param in model.parameters():
#     param.requires_grad = False

classes = ['ai','nature']

mlp_head = nn.Sequential(
    nn.Linear(model.head.in_features, Config.hidden_dim1),
    nn.ReLU(),
    nn.Dropout(Config.dropout),
    nn.Linear(Config.hidden_dim1, Config.hidden_dim2),
    nn.ReLU(),
    nn.Dropout(Config.dropout),
    nn.Linear(Config.hidden_dim2, len(classes))
)

# Update the classifier head to the new MLP
model.head = mlp_head

# Move the complete model (backbone + new head) to the GPU if available.
# Module.to works in place, so `mlp_head` ends up on the same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [None]:
# Adam over all model parameters, plus cross-entropy for the class logits.
optimizer = optim.Adam(params=model.parameters(), lr=Config.learning_rate)
criterion = nn.CrossEntropyLoss()

def accuracy(predictions, labels):
    """Fraction of rows whose highest-scoring class (over dim 1) equals the label.

    Args:
        predictions: scores of shape (batch, num_classes).
        labels: class indices of shape (batch,).

    Returns:
        Python float in [0, 1].
    """
    predicted_class = predictions.argmax(dim=1)
    hits = predicted_class.eq(labels)
    return hits.float().mean().item()

In [None]:
# BigGen_train is ConcatDataset([nature, ai]). An unshuffled loader yields every
# 'nature' batch before any 'ai' batch, so each optimiser step sees one class
# only. Shuffle the training set for that reason. Validation order does not
# affect the metrics, so it stays deterministic.
# Config.num_workers is the worker count configured for the data loader.
trainloader = DataLoader(BigGen_train, batch_size=Config.batch_size, shuffle=True,
                         num_workers=Config.num_workers)
testloader = DataLoader(BigGen_val, batch_size=Config.batch_size, shuffle=False,
                        num_workers=Config.num_workers)

In [None]:
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import time  # Import the time module

# Train/validate for Config.num_epochs epochs with mixed precision. The run
# resumes from the latest checkpoint (CheckPoint / State from data_prep_util)
# and saves a checkpoint after every epoch.
#
# The last batch's inputs / labels / dataset / image_name remain as globals;
# the GPU clean-up cell further down deletes them.

# Mixed precision only applies on CUDA. Disabling it explicitly on CPU keeps a
# single code path without autocast/GradScaler "CUDA not available" warnings.
use_amp = device.type == "cuda"
scaler = GradScaler(enabled=use_amp)  # Initialize the GradScaler

train_loss_history, val_loss_history, train_acc_history, val_acc_history = [], [], [], []
total_training_start = time.time()  # Record the start time of the total training
starting_epoch = 0
load_latest_model = True

# Only read the checkpoint store when resuming is requested.
state = CheckPoint.load_checkpoint() if load_latest_model else None
if state is not None:
    print('LOADING FROM CHECKPOINT')
    model.load_state_dict(state.model_state_dict)
    starting_epoch = state.epoch + 1
    # state.trainloader / state.testloader are deliberately NOT restored. A
    # DataLoader keeps no iteration progress between epochs, so the pickled
    # copies add nothing. Restoring them would silently replace the loaders
    # built from the current Config (batch size, shuffling, workers). They are
    # still saved below, so the checkpoint format is unchanged.
    train_loss_history = state.train_loss_history
    train_acc_history = state.train_acc_history
    val_loss_history = state.val_loss_history
    val_acc_history = state.val_acc_history
    criterion.load_state_dict(state.criterion_state_dict)
    optimizer.load_state_dict(state.optimizer_state_dict)
    scaler.load_state_dict(state.scaler_state_dict)
    print('starting_epoch', starting_epoch)
    print('train_loss_history', train_loss_history)
    print('train_acc_history', train_acc_history)
    print('val_loss_history', val_loss_history)
    print('val_acc_history', val_acc_history)

for epoch in range(starting_epoch, Config.num_epochs):
    epoch_start = time.time()  # Record the start time of the epoch

    # Per-sample sums (batch mean x batch size). This keeps a short final batch
    # from weighing as much as a full one in the epoch averages.
    train_loss_sum, train_correct, train_seen = 0.0, 0.0, 0
    val_loss_sum, val_correct, val_seen = 0.0, 0.0, 0

    # Training Phase
    model.train()
    pbar = tqdm(trainloader, total=len(trainloader),
                desc=f"Epoch {epoch+1} TRAIN", ncols=100)
    for data in pbar:
        inputs = data['image'].to(device=device, dtype=torch.float)
        labels = data['model_type'].to(device)

        # Not used by training; kept so the clean-up cell can delete them.
        dataset = data['dataset_type']
        image_name = data['image_name']

        optimizer.zero_grad()

        with autocast(enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()  # one host sync per step
        batch_len = labels.size(0)
        train_loss_sum += batch_loss * batch_len
        train_correct += accuracy(outputs, labels) * batch_len
        train_seen += batch_len

        pbar.set_description(f"Epoch {epoch+1} TRAIN Loss: {batch_loss:.4f}")

    # Validation Phase
    model.eval()
    pbar = tqdm(testloader, total=len(testloader),
                desc=f"Epoch {epoch+1} VAL", ncols=100)
    with torch.no_grad():
        for data in pbar:
            inputs = data['image'].to(device=device, dtype=torch.float)
            labels = data['model_type'].to(device)

            dataset = data['dataset_type']
            image_name = data['image_name']

            with autocast(enabled=use_amp):
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            batch_loss = loss.item()
            batch_len = labels.size(0)
            val_loss_sum += batch_loss * batch_len
            val_correct += accuracy(outputs, labels) * batch_len
            val_seen += batch_len

            pbar.set_description(
                f"Epoch {epoch+1} VAL Loss: {batch_loss:.4f}")

    epoch_end = time.time()  # Record the end time of the epoch
    # Calculate the duration in minutes
    epoch_duration = (epoch_end - epoch_start) / 60

    # max(..., 1) guards against an empty loader (ZeroDivisionError).
    train_loss = train_loss_sum / max(train_seen, 1)
    train_acc = train_correct / max(train_seen, 1)
    val_loss = val_loss_sum / max(val_seen, 1)
    val_acc = val_correct / max(val_seen, 1)

    train_loss_history.append(train_loss)
    train_acc_history.append(train_acc)
    val_loss_history.append(val_loss)
    val_acc_history.append(val_acc)

    print(f"Epoch Summary {epoch+1}/{Config.num_epochs}")
    print(
        f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc*100:.2f}%, Val Loss: {val_loss:.4f}, Val Acc: {val_acc*100:.2f}%")
    print(f"Epoch Duration: {epoch_duration:.2f} minutes")
    print('-' * 60)

    state = State(
        model_state_dict = model.state_dict(),
        epoch = epoch,
        trainloader = trainloader,
        testloader = testloader,
        train_loss_history = train_loss_history,
        train_acc_history = train_acc_history,
        val_loss_history = val_loss_history,
        val_acc_history = val_acc_history,
        criterion_state_dict = criterion.state_dict(),
        optimizer_state_dict = optimizer.state_dict(),
        scaler_state_dict = scaler.state_dict(),
    )
    CheckPoint.save_checkpoint(state)

total_training_end = time.time()  # Record the end time of the total training
# Calculate the total duration in minutes
total_training_duration = (total_training_end - total_training_start) / 60
print('Finished Training')
print(f"Total Training Time: {total_training_duration:.2f} minutes")

# https://stackoverflow.com/questions/59129812/how-to-avoid-cuda-out-of-memory-in-pytorch

In [None]:
# from itertools import islice

# for i,data in enumerate(trainloader):
#   if i>5200:
#     print(data['image'].shape)
#     print(i)


In [None]:
# plots with errors
# /content/sample_data/BigGAN/imagenet_ai_0419_biggan/train/ai/116_biggan_00098.png
# /content/sample_data/BigGAN/imagenet_ai_0419_biggan/train/ai/116_biggan_00107.png
# Render one train/ai PNG from the same batch of files for visual inspection.
fig, ax = plt.subplots()
img = mpimg.imread('/content/sample_data/BigGAN/imagenet_ai_0419_biggan/train/ai/116_biggan_00094.png')
imgplot = ax.imshow(img)
ax.axis("off")
plt.show()

In [None]:
import gc

# Release GPU memory held by the training run. `del model` alone frees nothing
# while other names still reference CUDA tensors:
# - the optimizer (parameter references + Adam state);
# - the checkpoint `state` (model.state_dict() returns references to the
#   parameters);
# - `mlp_head`;
# - the last batch's outputs/loss.
# Names are removed only if present, so the cell also works when the training
# loop ran zero epochs (e.g. resumed from a finished checkpoint).
for _name in ('model', 'mlp_head', 'optimizer', 'scaler', 'state',
              'outputs', 'loss', 'inputs', 'labels', 'dataset', 'image_name'):
    globals().pop(_name, None)
del _name
gc.collect()
torch.cuda.empty_cache()

In [None]:
# memory_summary() returns a multi-line table as a string. Print it so it renders
# rather than appearing as an escaped repr. It needs a CUDA device, so skip it on
# CPU runtimes.
if torch.cuda.is_available():
    print(torch.cuda.memory_summary(device=None, abbreviated=False))
else:
    print('CUDA is not available; no GPU memory summary.')