# Unzip File

In [None]:
# Mount Google Drive so the zipped GenImage/BigGAN archive is reachable.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Work from the Drive folder that holds the BigGAN archive.
%cd /content/drive/My Drive/OMSCS/OMSCS_DL_Project/GenImage/BigGAN

In [None]:
! ls

In [None]:
! pwd

In [None]:
! ls /content/sample_data

In [None]:
# Extract to local Colab storage (/content/sample_data) rather than Drive;
# the later cells read images from /content/sample_data/BigGAN/.
! unzip unsplit.zip -d /content/sample_data/BigGAN/

In [None]:
# NOTE(review): this Drive path (MyDrive/OMSCS_DL_Project/...) differs from the
# %cd path used above (My Drive/OMSCS/OMSCS_DL_Project/...); verify before uncommenting.
# %cp -r /content/sample_data/BigGAN/ /content/drive/MyDrive/OMSCS_DL_Project/GenImage/BigGAN/

In [None]:
# NOTE(review): `--out unsplit.zip` writes the archive that the unzip cell above reads,
# so this would need to run before it. It also names the ADM archive
# (imagenet_ai_0508_adm) rather than BigGAN — presumably a leftover; confirm.
# ! zip -F imagenet_ai_0508_adm.zip --out unsplit.zip

## Check the Number of Files

In [None]:
# Count the files in each split/class directory (e.g. to check the extraction).
! ls /content/sample_data/BigGAN/imagenet_ai_0419_biggan/train/ai | wc -l

In [None]:
! ls /content/sample_data/BigGAN/imagenet_ai_0419_biggan/train/nature/ | wc -l

In [None]:
! ls /content/sample_data/BigGAN/imagenet_ai_0419_biggan/val/ai | wc -l

In [None]:
! ls /content/sample_data/BigGAN/imagenet_ai_0419_biggan/val/nature/ | wc -l

## Show a Sample Image

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
# Display one real ("nature") training image straight from disk.
sample_path = '/content/sample_data/BigGAN/imagenet_ai_0419_biggan/train/nature/n01582220_4551.JPEG'
plt.figure()
img = mpimg.imread(sample_path)
imgplot = plt.imshow(img)
plt.axis("off")
plt.show()

# Prepare Dataloader

In [None]:
import os
from os import listdir
from os.path import isfile, join
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torchvision.io import read_image

import torch

from skimage import io, transform

import matplotlib.pyplot as plt

%matplotlib inline

In [None]:
# List the file names of the validation 'nature' images.
root_dir = '/content/sample_data/BigGAN/imagenet_ai_0419_biggan'
dataset_type, model_type = 'val', 'nature'
sample_dir = os.path.join(root_dir, dataset_type, model_type)
image_name = os.listdir(sample_dir)


In [None]:
# Show one validation image.
# The path goes into its own variable so this cell does not overwrite the
# `image_name` listing and can be re-run safely. The listing is sorted because
# os.listdir returns names in arbitrary order.
sample_image_path = os.path.join(root_dir, dataset_type, model_type,
                                 sorted(image_name)[100])
plt.figure()
img = mpimg.imread(sample_image_path)
imgplot = plt.imshow(img)
plt.axis("off")
plt.show()

In [None]:
# Switch to the project code folder; presumably needed so the following
# `data_prep_util` / `Entropy` imports resolve — confirm.
%cd /content/drive/My Drive/OMSCS/OMSCS_DL_Project/Deep_Learning_Final_Project/Code

In [None]:
! ls

In [None]:
from data_prep_util import GenImageDataset, Rescale, HighPassConvLayer, State, CheckPoint
from Entropy import Entropy

# Define Hyperparameters

In [None]:
# Hyperparameters and configurations
class Config:
    """Notebook-wide hyperparameters, read as class attributes (``Config.x``).

    Experiment cells overwrite these attributes in place (e.g.
    ``Config.use_filter = False``) before building or training the model.
    """
    # for data loader
    batch_size = 32
    num_workers = 8

    # number of epochs during training
    num_epochs = 4

    # learning rate for learnable parameters
    learning_rate = 3e-5

    # Define an MLP with 2 or 3 layers
    hidden_dim1 = 500
    hidden_dim2 = 500

    # dropout in head
    dropout = 0.1

    # Set to False to disable the high pass filter
    use_filter = True

    # Adjust alpha between 0 and 1 for the desired effect for the high pass filter
    alpha_value = 0.5

    # Set to True if you want to use pretrained weights
    pretrained = False

    # Set to True if you want to use the entropy filter
    use_entropy_filter = False

    @staticmethod
    def print_values():
        """Print every hyperparameter as ``name: value``, one per line.

        A staticmethod, so it works both as ``Config.print_values()`` and on an
        instance. It always reads the current class attributes, including
        values changed by experiment cells.
        """
        for name in ('batch_size', 'num_workers', 'num_epochs', 'learning_rate',
                     'hidden_dim1', 'hidden_dim2', 'dropout', 'use_filter',
                     'alpha_value', 'pretrained', 'use_entropy_filter'):
            print(f'{name}:', getattr(Config, name))

# Define Transforms

In [None]:
def get_transformations(rescale_size=256):
    """Build the preprocessing pipeline applied to every dataset sample.

    Args:
        rescale_size: size passed to the project's ``Rescale`` transform.

    Returns:
        A ``transforms.Compose`` whose only step is ``Rescale(rescale_size)``.
    """
    return transforms.Compose([Rescale(rescale_size)])

In [None]:

# Training split: real ('nature') and generated ('ai') images, concatenated.
root_dir = '/content/sample_data/BigGAN/imagenet_ai_0419_biggan'
dataset_type = 'train'

train_nature = GenImageDataset(
    root_dir, dataset_type, 'nature',
    transform=get_transformations(), input_type='Image',
)
train_ai = GenImageDataset(
    root_dir, dataset_type, 'ai',
    transform=get_transformations(), input_type='Image',
)

train = torch.utils.data.ConcatDataset([train_nature, train_ai])

In [None]:
# Validation split built the same way; reuses `root_dir` from the training cell.
dataset_type = 'val'

val_nature = GenImageDataset(
    root_dir, dataset_type, 'nature',
    transform=get_transformations(), input_type='Image',
)
val_ai = GenImageDataset(
    root_dir, dataset_type, 'ai',
    transform=get_transformations(), input_type='Image',
)

val = torch.utils.data.ConcatDataset([val_nature, val_ai])

In [None]:
# Number of samples in the combined nature + ai training set.
len(train)

In [None]:
# Number of samples in the combined validation set.
len(val)

In [None]:
# Inspect one raw sample; the training loop reads its 'image', 'model_type',
# 'dataset_type' and 'image_name' fields.
train[100]

In [None]:
train[100]['image'].shape

In [None]:
# Render a dataset sample; the tensor is channel-first, so move channels last for imshow.
plt.figure()
sample = train[100]
img = sample['image']
imgplot = plt.imshow(img.permute(1, 2, 0))
plt.axis("off")
plt.show()

In [None]:
# Sanity check that the dataset is iterable; presumably yields the same sample as train[0].
next(iter(train))

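Background reading on DataLoader behaviour:

- DataLoader resets dataset state: https://discuss.pytorch.org/t/dataloader-resets-dataset-state/27960
- PyTorch DataLoaders in memory: https://discuss.pytorch.org/t/pytorch-dataloaders-in-memory/118471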

# Define Swin Transformer

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image

In [None]:
# Wrapper for Swin Transformer to allow optional conv layer
class CustomSwinModel(nn.Module):
    """Wrap a backbone with an optional high-pass filter in front of it.

    Args:
        base_model: module that receives the (optionally filtered) input.
        use_high_pass_filter: when True, inputs go through ``HighPassConvLayer``
            before ``base_model``; otherwise they are forwarded unchanged.
    """

    def __init__(self, base_model, use_high_pass_filter=False):
        super().__init__()
        self.use_high_pass_filter = use_high_pass_filter
        if use_high_pass_filter:
            self.high_pass_filter = HighPassConvLayer()
        else:
            self.high_pass_filter = nn.Identity()
        self.base_model = base_model

    def forward(self, x):
        features = self.high_pass_filter(x) if self.use_high_pass_filter else x
        return self.base_model(features)


# Initialize the Swin-T backbone. `pretrained=` is deprecated since torchvision
# 0.13, so weights are requested explicitly: ImageNet weights when
# Config.pretrained is True, random initialization (weights=None) otherwise.
swin_weights = models.Swin_T_Weights.DEFAULT if Config.pretrained else None
base_model = models.swin_t(weights=swin_weights)

# Create the custom model with the optional high-pass filter layer
model = CustomSwinModel(base_model, use_high_pass_filter=Config.use_filter)

# Move the model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# For transfer learning, uncomment to freeze the pretrained weights
# for param in model.parameters():
#     param.requires_grad = False

classes = ['ai', 'nature']

# Replace the backbone's linear classifier with a small MLP head that
# outputs one logit per class.
mlp_head = nn.Sequential(
    nn.Linear(model.base_model.head.in_features, Config.hidden_dim1),
    nn.ReLU(),
    nn.Dropout(Config.dropout),
    nn.Linear(Config.hidden_dim1, Config.hidden_dim2),
    nn.ReLU(),
    nn.Dropout(Config.dropout),
    nn.Linear(Config.hidden_dim2, len(classes))
).to(device)

# Update the classifier head of the base_model inside CustomSwinModel
model.base_model.head = mlp_head

In [None]:
# Adam over every parameter of `model` (backbone, MLP head, and any parameters
# the optional high-pass layer has).
optimizer = optim.Adam(model.parameters(), lr=Config.learning_rate)
# Cross-entropy on the raw logits produced by the MLP head's final Linear layer.
criterion = nn.CrossEntropyLoss()

def accuracy(predictions, labels):
    """Fraction of rows whose highest-scoring class equals the label.

    Args:
        predictions: score tensor whose class axis is dim 1.
        labels: integer class indices, one per row of ``predictions``.

    Returns:
        Accuracy as a Python float in [0, 1].
    """
    predicted_classes = predictions.max(dim=1).indices
    hits = predicted_classes.eq(labels)
    return hits.float().mean().item()

In [None]:
# Batch loaders for training and validation.
# - Config.num_workers was defined but never passed to the loaders; it is used now.
# - Training batches are reshuffled every epoch. Validation keeps a fixed order,
#   so the per-batch-averaged validation metrics are deterministic across epochs.
trainloader = DataLoader(train, batch_size=Config.batch_size, shuffle=True,
                         num_workers=Config.num_workers)
testloader = DataLoader(val, batch_size=Config.batch_size, shuffle=False,
                        num_workers=Config.num_workers)

# Experiment 1: High-Pass Filter

> *The following code block contains visualizations related to Experiment 1.*

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.fft import fft2, fftshift

# Experiment 1 flags. This cell does not rebuild the model: use_filter and
# pretrained only take effect when a model is constructed. use_entropy_filter is
# read at run time by the training loop.
Config.use_filter = True
Config.use_entropy_filter = False
Config.pretrained = False


def apply_high_pass_filter(image, filter_layer, target_device=None):
    """Run a single image through ``filter_layer`` for visualization.

    Args:
        image: tensor for one image. A batch dimension is added at dim 0, so
            presumably (C, H, W) — confirm against the dataset.
        filter_layer: module applied to the batched, float32 image.
        target_device: device to run on; defaults to the notebook-global ``device``.

    Returns:
        The filtered image with the batch dimension removed. It is computed
        under ``torch.no_grad()``, so it carries no autograd graph and
        ``.numpy()`` works even when the filter has trainable weights.
    """
    if target_device is None:
        target_device = device

    # Convert to floating point and add a batch dimension
    image_batch = image.unsqueeze(0).to(torch.float32).to(target_device)

    with torch.no_grad():
        filtered_image_batch = filter_layer(image_batch)

    # Remove the batch dimension
    return filtered_image_batch.squeeze(0)


def plot_frequency_spectrum(image, title):
    """Draw the log-magnitude 2-D Fourier spectrum of ``image`` on the current axes.

    Args:
        image: channel-first (C, H, W) tensor or a single (H, W) plane. Channels
            are averaged, for any channel count, so one spectrum is drawn.
        title: axes title.
    """
    # Detach so .numpy() works even if the tensor came out of a learnable layer
    image = image.detach()
    if not image.is_floating_point():
        image = image.to(torch.float32)
    if image.dim() == 3:
        image = image.mean(0)  # average across the color channels

    # Shift only the two spatial axes so the zero frequency lands in the center
    fshift = fftshift(fft2(image), dim=(-2, -1))

    # Log scale (log(1 + |F|)) keeps the low-magnitude frequencies visible
    magnitude_spectrum = torch.log1p(torch.abs(fshift))

    plt.imshow(magnitude_spectrum.cpu().numpy(), cmap='gray')
    plt.title(title)
    plt.axis('off')


# Instantiate a standalone high-pass filter layer for visualization
high_pass_filter_layer = HighPassConvLayer().to(device)

# Get an image from the dataset; scale uint8 pixels to floats in [0, 1]
original_image = train[0]['image']
if original_image.dtype == torch.uint8:
    original_image = original_image.to(torch.float32) / 255.

# Apply the filter without building an autograd graph. If HighPassConvLayer has
# trainable weights, its output would otherwise require grad and .numpy() would fail.
with torch.no_grad():
    filtered_image = apply_high_pass_filter(original_image, high_pass_filter_layer)
filtered_image = filtered_image.detach().cpu()

# The filter output is signed and concentrated near zero. imshow clips float RGB
# data to [0, 1], so a min-max rescaled copy is shown instead. The spectrum
# below still uses the raw filtered values.
display_filtered = filtered_image - filtered_image.min()
display_filtered = display_filtered / display_filtered.max().clamp_min(1e-8)

# Display the original and filtered images and their frequency spectrums
plt.figure(figsize=(12, 12))

# Original image
plt.subplot(2, 2, 1)
plt.imshow(original_image.permute(1, 2, 0).cpu().numpy())
plt.title("Original Image")
plt.axis("off")

# Frequency spectrum of the original image
plt.subplot(2, 2, 2)
plot_frequency_spectrum(original_image, "Original Image Frequency Spectrum")

# Filtered image (squeeze drops a trailing singleton channel, if any, for imshow)
plt.subplot(2, 2, 3)
plt.imshow(display_filtered.permute(1, 2, 0).squeeze(-1).numpy())
plt.title("Filtered Image")
plt.axis("off")

# Frequency spectrum of the filtered image
plt.subplot(2, 2, 4)
plot_frequency_spectrum(filtered_image, "Filtered Image Frequency Spectrum")

plt.tight_layout()
plt.show()

# Experiment 3 Setup: Entropy Filter

> *Run the following block before the training loop to perform Experiment 3.*

In [None]:
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler

# Experiment 3: no high-pass filter; the training loop applies the entropy filter instead.
Config.use_filter = False
Config.use_entropy_filter = True
Config.pretrained = False
Config.print_values()

# Rebuild a fresh model for this experiment. `pretrained=` is deprecated since
# torchvision 0.13, so weights are requested explicitly (None = random init).
swin_weights = models.Swin_T_Weights.DEFAULT if Config.pretrained else None
base_model = models.swin_t(weights=swin_weights)
model = CustomSwinModel(base_model, use_high_pass_filter=Config.use_filter)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

num_classes = 2
mlp_head = nn.Sequential(
    nn.Linear(model.base_model.head.in_features, Config.hidden_dim1),
    nn.ReLU(),
    nn.Dropout(Config.dropout),
    nn.Linear(Config.hidden_dim1, Config.hidden_dim2),
    nn.ReLU(),
    nn.Dropout(Config.dropout),
    nn.Linear(Config.hidden_dim2, num_classes)
).to(device)
model.base_model.head = mlp_head

# Fresh optimizer/loss/scaler bound to the new model's parameters
optimizer = optim.Adam(model.parameters(), lr=Config.learning_rate)
criterion = nn.CrossEntropyLoss()
scaler = GradScaler()

# Visual sanity check: a training image next to its entropy map.
# NOTE(review): here entropy_for_image receives a torch tensor, while the
# training loop passes NumPy arrays; confirm both input types are supported.
original_image = train[1]['image']
plt.imshow(original_image.permute(1, 2, 0))
plt.show()

entropy_filtered = Entropy.entropy_for_image(None, original_image)
plt.imshow(entropy_filtered[0], cmap='gray')
plt.show()

# Training Loop

In [None]:
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import time  # Import the time module

scaler = GradScaler()  # Initialize the GradScaler (loss scaling for autocast)

train_loss_history, val_loss_history, train_acc_history, val_acc_history = {}, {}, {}, {}
total_training_start = time.time()  # Record the start time of the total training
starting_epoch = 0
load_latest_model = False  # set True to resume from the last saved checkpoint


def _prepare_inputs(batch):
    """Return ``batch['image']`` as a float tensor on ``device``.

    When ``Config.use_entropy_filter`` is set, each image is replaced by the
    output of ``Entropy.entropy_for_image``. Both the training and validation
    phases call this so they see identical preprocessing.
    """
    inputs = batch['image']
    if Config.use_entropy_filter:
        # NOTE(review): assumes entropy_for_image returns a tensor, since the
        # per-image results are combined with torch.stack.
        inputs = torch.stack([Entropy.entropy_for_image(None, image)
                              for image in inputs.numpy()])
    return inputs.to(torch.float).to(device)


# Only read the checkpoint store when resuming was actually requested.
state = CheckPoint.load_checkpoint() if load_latest_model else None
if state is not None:
    print('LOADING FROM CHECKPOINT')
    model.load_state_dict(state.model_state_dict)
    starting_epoch = state.epoch + 1
    trainloader = state.trainloader
    testloader = state.testloader
    train_loss_history = state.train_loss_history
    train_acc_history = state.train_acc_history
    val_loss_history = state.val_loss_history
    val_acc_history = state.val_acc_history
    criterion.load_state_dict(state.criterion_state_dict)
    optimizer.load_state_dict(state.optimizer_state_dict)
    scaler.load_state_dict(state.scaler_state_dict)
    print('starting_epoch', starting_epoch)
    print('train_loss_history', train_loss_history)
    print('train_acc_history', train_acc_history)
    print('val_loss_history', val_loss_history)
    print('val_acc_history', val_acc_history)

for epoch in range(starting_epoch, Config.num_epochs):
    epoch_start = time.time()  # Record the start time of the epoch

    train_loss, train_acc, val_loss, val_acc = 0.0, 0.0, 0.0, 0.0
    train_loss_history[epoch] = []
    val_loss_history[epoch] = []
    train_acc_history[epoch] = []
    val_acc_history[epoch] = []

    # Training Phase
    model.train()
    pbar = tqdm(enumerate(trainloader), total=len(trainloader),
                desc=f"Epoch {epoch+1} TRAIN", ncols=100)
    for i, data in pbar:
        inputs = _prepare_inputs(data)
        labels = data['model_type'].to(device)

        # Kept at module scope: a later cell deletes `dataset` and `image_name`.
        dataset = data['dataset_type']
        image_name = data['image_name']

        optimizer.zero_grad()

        with autocast():
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        acc = accuracy(outputs, labels)
        batch_loss = loss.item()
        train_loss_history[epoch].append(batch_loss)
        train_acc_history[epoch].append(acc)

        train_loss += batch_loss
        train_acc += acc

        pbar.set_description(f"Epoch {epoch+1} TRAIN Loss: {batch_loss:.4f}")

    # Validation Phase (same preprocessing as training, including the entropy filter)
    model.eval()
    pbar = tqdm(enumerate(testloader), total=len(testloader),
                desc=f"Epoch {epoch+1} VAL", ncols=100)
    with torch.no_grad():
        for i, data in pbar:
            inputs = _prepare_inputs(data)
            labels = data['model_type'].to(device)

            dataset = data['dataset_type']
            image_name = data['image_name']

            with autocast():
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            acc = accuracy(outputs, labels)
            batch_loss = loss.item()
            val_loss_history[epoch].append(batch_loss)
            val_acc_history[epoch].append(acc)

            val_loss += batch_loss
            val_acc += acc

            pbar.set_description(
                f"Epoch {epoch+1} VAL Loss: {batch_loss:.4f}")

    epoch_end = time.time()  # Record the end time of the epoch
    # Calculate the duration in minutes
    epoch_duration = (epoch_end - epoch_start) / 60

    # Epoch metrics are means of the per-batch values
    train_loss /= len(trainloader)
    train_acc /= len(trainloader)
    val_loss /= len(testloader)
    val_acc /= len(testloader)

    print(f"Epoch Summary {epoch+1}/{Config.num_epochs}")
    print(
        f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc*100:.2f}%, Val Loss: {val_loss:.4f}, Val Acc: {val_acc*100:.2f}%")
    print(f"Epoch Duration: {epoch_duration:.2f} minutes")
    print('-' * 60)

    # Save everything needed to resume after this epoch
    state = State(
        model_state_dict=model.state_dict(),
        epoch=epoch,
        trainloader=trainloader,
        testloader=testloader,
        train_loss_history=train_loss_history,
        train_acc_history=train_acc_history,
        val_loss_history=val_loss_history,
        val_acc_history=val_acc_history,
        criterion_state_dict=criterion.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        scaler_state_dict=scaler.state_dict(),
    )
    CheckPoint.save_checkpoint(state)

total_training_end = time.time()  # Record the end time of the total training
# Calculate the total duration in minutes
total_training_duration = (total_training_end - total_training_start) / 60
print('Finished Training')
print(f"Total Training Time: {total_training_duration:.2f} minutes")

# https://stackoverflow.com/questions/59129812/how-to-avoid-cuda-out-of-memory-in-pytorch

# Plot Results

In [None]:
import matplotlib.pyplot as plt

def _flatten_history(history):
    """Concatenate per-epoch batch metrics and give each one a fractional-epoch x value.

    Batch ``i`` of ``n`` in epoch ``e`` is placed at ``e + (i + 1) / n``, so the
    train and validation curves share one epoch axis even though they have
    different numbers of batches per epoch.
    """
    xs, ys = [], []
    for epoch in sorted(history):
        values = history[epoch]
        n_batches = len(values)
        xs.extend(epoch + (i + 1) / n_batches for i in range(n_batches))
        ys.extend(values)
    return xs, ys


# Combine the per-epoch metric maps into flat series
train_loss_x, training_losses = _flatten_history(train_loss_history)
train_acc_x, training_acc = _flatten_history(train_acc_history)
val_loss_x, validation_losses = _flatten_history(val_loss_history)
val_acc_x, validation_acc = _flatten_history(val_acc_history)

fig, ax = plt.subplots()
ax.plot(train_loss_x, training_losses, label='train')
ax.plot(val_loss_x, validation_losses, label='validation')
ax.legend(loc='upper left')
ax.set(xlabel='epoch', ylabel='loss', title='Per-batch loss')
fig.savefig('losses.png')
plt.show()

fig, ax = plt.subplots()
ax.plot(train_acc_x, training_acc, label='train')
ax.plot(val_acc_x, validation_acc, label='validation')
ax.legend(loc='upper left')
ax.set(xlabel='epoch', ylabel='accuracy', title='Per-batch accuracy')
fig.savefig('accuracies.png')
plt.show()

In [None]:
# from itertools import islice

# for i,data in enumerate(trainloader):
#   if i>5200:
#     print(data['image'].shape)
#     print(i)


In [None]:
# plots with errors
# /content/sample_data/BigGAN/imagenet_ai_0419_biggan/train/ai/116_biggan_00098.png
# /content/sample_data/BigGAN/imagenet_ai_0419_biggan/train/ai/116_biggan_00107.png
# Display one BigGAN training image for manual inspection.
error_image_path = '/content/sample_data/BigGAN/imagenet_ai_0419_biggan/train/ai/116_biggan_00094.png'
plt.figure()
img = mpimg.imread(error_image_path)
imgplot = plt.imshow(img)
plt.axis("off")
plt.show()

In [None]:
import gc

# Drop references to the model and last batch so their memory can be reclaimed.
# Names are removed only if present: `del` would raise NameError when training
# never ran or when this cell is executed twice.
for _name in ('model', 'inputs', 'labels', 'dataset', 'image_name'):
    globals().pop(_name, None)
gc.collect()
torch.cuda.empty_cache()  # return cached, now-unused CUDA blocks to the driver

In [None]:
# memory_summary returns a multi-line report string; print it so the table renders
# instead of showing as a single repr with literal "\n" escapes.
print(torch.cuda.memory_summary(device=None, abbreviated=False))