<a href="https://colab.research.google.com/github/likithpala7/real-fake-detector/blob/main/train_mc.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the third-party packages this notebook needs: CLIP straight from its
# GitHub repository, plus ftfy/regex/tqdm and peft from PyPI.
# NOTE(review): none of these are version-pinned, so a fresh runtime may resolve
# different releases than the ones this notebook was last run with.
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git
!pip install peft

In [2]:
# Mount Google Drive at /content/drive; the dataset archive is read from it and
# all checkpoints, stats and plots are written back under drive/MyDrive/GenImage.
from google.colab import drive
from google.colab import files
drive.mount('/content/drive')

Mounted at /content/drive


In [1]:
from torch.utils.data import Dataset
import json
import torch
import clip
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
import json
import os
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import numpy as np
import pandas as pd
import time
from torch.utils.data import DataLoader
from torchvision.ops import sigmoid_focal_loss
import torch
import matplotlib.pyplot as plt
from collections import namedtuple
import pandas as pd
from copy import deepcopy
from tqdm import tqdm
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
# Extract the dataset archive from Drive into the current working directory
# (quietly); the rest of the notebook reads it from ./imagenet_ai_holdout.
!unzip -q drive/MyDrive/GenImage/imagenet_ai_holdout.zip

In [6]:
# Drop three specific training images before building the dataset.
# NOTE(review): presumably these files are corrupt/unreadable -- confirm and record the reason.
!rm imagenet_ai_holdout/ADM/train/115_adm_156.PNG
!rm imagenet_ai_holdout/BigGAN/train/116_biggan_00098.png
!rm imagenet_ai_holdout/BigGAN/train/116_biggan_00107.png

In [2]:
# Map every entry of the dataset root (one folder per image source) to an
# integer class id; ids follow the alphabetical order of the folder names.
label_dict = {
    source_name: class_id
    for class_id, source_name in enumerate(sorted(os.listdir('imagenet_ai_holdout')))
}

class ImageDataset(Dataset):
    """Dataset over ``imagenet_ai_holdout/<source>/<split>/<image>``.

    Every item is a pair ``(image, label)`` where ``image`` is the
    CLIP-preprocessed image tensor and ``label`` is the integer class id that
    the module-level ``label_dict`` assigns to the ``<source>`` folder.
    Both tensors are returned already placed on the module-level ``device``.
    """

    def get_dataset(self, d_type):
        """
        Collects the paths of all images that belong to one split
        :param d_type: Name of the split sub-folder (e.g. 'train' or 'val')
        :return: List of image paths relative to the working directory
        """
        root = 'imagenet_ai_holdout'
        data = []

        # Sort both levels so the sample order does not depend on the order in
        # which the filesystem happens to list the entries.
        for dataset in sorted(os.listdir(root)):
            split_dir = os.path.join(root, dataset, d_type)
            data.extend(os.path.join(split_dir, img)
                        for img in sorted(os.listdir(split_dir)))

        return data

    def __init__(self, d_type='train'):
        """
        :param d_type: Name of the split sub-folder to read (default 'train')
        """
        self.data = self.get_dataset(d_type)

        # Only the preprocessing transform is kept; the model returned next to
        # it is discarded, so load it on the CPU instead of occupying
        # accelerator memory with a throw-away copy.
        _, self.preprocess = clip.load('ViT-B/32', device='cpu')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """
        :param idx: Index into the list of image paths
        :return: Tuple (preprocessed image tensor, class-id tensor), both on ``device``
        """
        img_path = self.data[idx]

        # Paths look like imagenet_ai_holdout/<source>/<split>/<file>, so the
        # second path component names the class.
        model_type = img_path.split(os.path.sep)[1]
        label = torch.tensor(label_dict[model_type], device=device)

        # Use a context manager so the underlying file handle is released as
        # soon as the image has been turned into a tensor.
        with Image.open(img_path) as raw_img:
            img = self.preprocess(raw_img).to(device)

        return img, label

In [3]:
# Width of the image embedding fed into the classification head.
CLIP_HIDDEN_STATE = 512

def print_trainable_parameters(model):
    """
    Prints how many of the model's parameters are trainable, next to the
    total parameter count and the trainable share in percent.
    :param model: Module whose ``named_parameters()`` are counted
    """
    # (element count, requires_grad) for every parameter tensor
    param_info = [(tensor.numel(), tensor.requires_grad)
                  for _, tensor in model.named_parameters()]

    all_param = sum(count for count, _ in param_info)
    trainable_params = sum(count for count, needs_grad in param_info if needs_grad)

    print(
        f"Trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )


class CLIPModel(nn.Module):
    """
    CLIP ViT-B/32 image encoder with all parameters frozen, followed by a
    small trainable two-layer classification head that outputs one logit per
    entry of ``label_dict``.
    """

    def __init__(self):
        super().__init__()

        backbone, _unused_preprocess = clip.load('ViT-B/32', device=device)

        # The encoder is only used as a fixed feature extractor.
        backbone.requires_grad_(False)
        self.clip_model = backbone

        # Classification head (layer order kept so state-dict keys stay fc.0 / fc.3)
        hidden_dim = CLIP_HIDDEN_STATE // 2
        head_layers = [
            nn.Linear(CLIP_HIDDEN_STATE, hidden_dim),
            nn.Dropout(0.2),
            nn.ReLU(),
            nn.Linear(hidden_dim, len(label_dict)),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """
        :param x: Batch of preprocessed images as accepted by ``encode_image``
        :return: Class logits produced by the head
        """
        # Cast the embedding to float32 before it enters the head.
        features = self.clip_model.encode_image(x).float()
        return self.fc(features)

In [4]:
# Immutable bundle of the run's hyper-parameters and checkpoint options.
Config = namedtuple(
    'Instance',
    'batch_size learning_rate weight_decay num_workers '
    'epochs load_checkpoint file_checkpoint',
)

config = Config(**{
    'batch_size': 512,
    'learning_rate': 1e-3,
    'weight_decay': 0.05,
    'num_workers': 0,
    'epochs': 5,
    # Set to True and name the run folder below to continue an earlier run
    'load_checkpoint': False,
    'file_checkpoint': '',
})

In [None]:
def save_model(model, model_name):
    """
    Writes ``model`` with ``torch.save`` into the current run's results folder
    :param model: Object to serialize (the training loop passes a state dict)
    :param model_name: File name without the '.pth' extension
    """
    run_dir = os.path.join('drive', 'MyDrive', 'GenImage', 'results', timestr)
    torch.save(model, os.path.join(run_dir, model_name + '.pth'))


def val_model(dloader, val_model):
    """
    Evaluates a model on every batch of a dataloader and saves a confusion
    matrix of the collected predictions.
    :param dloader: Dataloader yielding (inputs, labels) batches
    :param val_model: Model to evaluate; it is switched to eval mode
    :return: Tuple (mean per-batch loss, mean per-batch accuracy)
    :raises ValueError: If the dataloader yields no batches
    """
    num_batches = len(dloader)
    if num_batches == 0:
        raise ValueError('val_model received an empty dataloader')

    val_model.eval()
    val_loss = 0.0
    val_accuracy = 0.0
    predictions, label_list = [], []

    criterion = torch.nn.CrossEntropyLoss()

    with torch.no_grad():

        for inputs, labels in tqdm(dloader, total=num_batches):

            logits = val_model(inputs)

            # Accumulate (not overwrite) so the average below covers all batches
            val_loss += criterion(logits, labels).item()

            val_acc, preds = compute_accuracy(logits, labels)
            val_accuracy += val_acc
            predictions.extend(preds.tolist())
            label_list.extend(labels.tolist())

    plot_confusion_matrix(predictions, label_list)

    # Average over the loader that was actually evaluated, not a global one
    return val_loss / num_batches, val_accuracy / num_batches


def save_stats(train_loss, val_loss, train_acc, val_acc, epochs,
               lr, train_accs, val_accs):
    """
    Dumps the training statistics of the current run to 'stats.json'
    :param train_loss: Lowest training loss so far
    :param val_loss: Lowest validation loss so far
    :param train_acc: Highest training accuracy so far
    :param val_acc: Highest validation accuracy so far
    :param epochs: Number of epochs completed
    :param lr: Learning rate to store alongside the stats
    :param train_accs: Per-epoch training accuracies
    :param val_accs: Per-epoch validation accuracies

    The per-epoch loss histories are not parameters: they are read from the
    module-level lists ``losses`` and ``val_losses``.
    """
    history = {
        'losses': losses,
        'val losses': val_losses,
        'training accuracies': train_accs,
        'val accuracies': val_accs,
    }
    summary = {
        'min train loss': train_loss,
        'min val loss': val_loss,
        'max train acc': train_acc,
        'max val acc': val_acc,
        'epochs': epochs,
        'learning rate': lr,
    }
    # Merge in this order so the JSON keys keep their established layout
    stats_dict = {**history, **summary}

    # Save stats into checkpoint
    stats_path = os.path.join('drive', 'MyDrive', 'GenImage', 'results', timestr, 'stats.json')
    with open(stats_path, 'w') as stats_file:
        json.dump(stats_dict, stats_file)

def compute_accuracy(logits, labels):
    """
    Computes the fraction of correctly classified samples in a batch
    :param logits: Class scores, one row per sample
    :param labels: Ground-truth class ids, one per sample
    :return: Tuple (accuracy as a Python float, predicted class ids)
    """
    class_probs = logits.softmax(dim=-1)
    preds = class_probs.argmax(dim=-1)
    num_correct = torch.eq(preds, labels).sum().item()
    batch_size = labels.shape[0]
    return num_correct / batch_size, preds

def plot_confusion_matrix(predictions, labels):
    """
    Saves a row-normalized confusion matrix heatmap for the current run
    :param predictions: Predicted class ids
    :param labels: Ground-truth class ids

    The image is written to '<run folder>/Confusion Matrices/cf_<n>.png',
    where <n> is the number of files already present in that folder. The
    figure is closed after saving, so it is not kept alive between epochs.
    """
    # Define the class names and order (same ordering that label_dict is built from)
    classes = sorted(os.listdir('imagenet_ai_holdout'))

    # Pin the label set so the matrix is always len(classes) x len(classes),
    # even when some class never occurs in this split or in the predictions.
    cm = confusion_matrix(labels, predictions, labels=list(range(len(classes))))

    # Normalize each row by its sum; rows without any true sample stay at 0
    # instead of turning into NaN through a division by zero.
    row_sums = cm.sum(axis=1, keepdims=True)
    normalized_cm = np.divide(cm, row_sums,
                              out=np.zeros(cm.shape, dtype=float),
                              where=row_sums != 0)

    # Convert normalized confusion matrix to DataFrame with named rows and columns
    normalized_cm_df = pd.DataFrame(normalized_cm, index=classes, columns=classes)

    fig = plt.figure(figsize=(8, 6))
    sns.set(font_scale=1.4)  # Adjust to fit labels properly

    # Create a heatmap plot
    sns.heatmap(normalized_cm_df, annot=True, fmt='.2f', cmap='Blues', cbar=False,
                annot_kws={"size": 16}, linewidths=1, linecolor='black')

    plt.title('Normalized Confusion Matrix')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.yticks(rotation=0)  # Ensure correct orientation of row labels

    # Adjust layout to ensure all borders are visible
    plt.tight_layout(pad=1.0)

    out_dir = os.path.join('drive', 'MyDrive', 'GenImage', 'results', timestr,
                           'Confusion Matrices')
    os.makedirs(out_dir, exist_ok=True)

    img_num = len(os.listdir(out_dir))
    fname = f'cf_{img_num}.png'
    plt.savefig(os.path.join(out_dir, fname))

    # One figure is created per call; release it so figures do not pile up
    plt.close(fig)


def plot_loss(training_loss, val_loss):
    """
    Plots training and validation loss per epoch and saves the plot as
    'loss.png' in the current run's results folder
    :param training_loss: Per-epoch training losses
    :param val_loss: Per-epoch validation losses
    """
    # Epochs are numbered from 1; the axis length follows the training history
    epoch_axis = range(1, len(training_loss) + 1)
    curves = (
        (training_loss, 'Training Loss'),
        (val_loss, 'Validation Loss'),
    )

    plt.clf()
    for series, legend_label in curves:
        plt.plot(epoch_axis, series, label=legend_label)
    plt.title('Training and Validation Loss')
    plt.xlabel('Num epochs')
    plt.ylabel('Loss')
    plt.legend()

    target = os.path.join('drive', 'MyDrive', 'GenImage', 'results', timestr, 'loss.png')
    plt.savefig(target)


def train(train_loss, val_loss, train_acc, val_acc, best_model, epochs, learning_rate, train_accs, val_accs):
    """
    Runs the training loop from epoch ``epochs`` up to ``config.epochs``
    :param train_loss: Lowest training loss so far (None for a fresh run)
    :param val_loss: Lowest validation loss so far (None for a fresh run)
    :param train_acc: Highest training accuracy so far (None for a fresh run)
    :param val_acc: Highest validation accuracy so far (None for a fresh run)
    :param best_model: Best weights so far (state dict; a module is also accepted)
    :param epochs: Number of epochs already completed
    :param learning_rate: Initial learning rate for the optimizer
    :param train_accs: Per-epoch training accuracies, appended to in place
    :param val_accs: Per-epoch validation accuracies, appended to in place
    :return: Tuple (min train loss, min val loss, max train acc, max val acc)

    Uses the module-level ``model``, ``config``, ``train_dataloader`` and
    ``val_dataloader``, and appends to the module-level ``losses`` and
    ``val_losses`` histories. After every epoch the best weights and the
    stats are written to the run folder.
    """
    # Only parameters that require gradients are optimized; the configured
    # weight decay is applied so it matches what save_experiment reports.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate,
                                  weight_decay=config.weight_decay)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

    criterion = torch.nn.CrossEntropyLoss()

    # Checkpoints are stored as state dicts; normalize a module handed in by
    # the caller so 'latest_model.pth' always has the same format.
    if isinstance(best_model, torch.nn.Module):
        best_model = deepcopy(best_model.state_dict())

    for epoch in range(epochs, config.epochs):
        print('-------------------- EPOCH ' + str(epoch) + ' ---------------------')
        model.train()
        epoch_loss = 0
        epoch_accuracy = 0

        for inputs, labels in tqdm(train_dataloader, total=len(train_dataloader)):

            logits = model(inputs)

            loss = criterion(logits, labels)

            epoch_loss += loss.item()
            acc, _ = compute_accuracy(logits, labels)
            epoch_accuracy += acc

            # Back-propagate
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        # Get train and val loss per batch
        epoch_train_loss = epoch_loss / len(train_dataloader)
        epoch_train_accuracy = epoch_accuracy / len(train_dataloader)
        losses.append(epoch_train_loss)

        epoch_val_loss, epoch_val_accuracy = val_model(val_dataloader, model)
        val_losses.append(epoch_val_loss)

        train_accs.append(epoch_train_accuracy)
        val_accs.append(epoch_val_accuracy)

        # Compare against None explicitly: a legitimate value of 0.0 must not
        # be mistaken for "no value recorded yet".
        if val_loss is None or epoch_val_loss <= val_loss:
            val_loss = epoch_val_loss
            best_model = deepcopy(model.state_dict())
        if train_loss is None or epoch_train_loss <= train_loss:
            train_loss = epoch_train_loss
        if train_acc is None or epoch_train_accuracy >= train_acc:
            train_acc = epoch_train_accuracy
        if val_acc is None or epoch_val_accuracy >= val_acc:
            val_acc = epoch_val_accuracy

        # Adjust learning rate scheduler
        scheduler.step()

        print('Training Loss: ' + str(epoch_train_loss))
        print('Validation Loss: ' + str(epoch_val_loss))
        print('Training Accuracy: ' + str(epoch_train_accuracy))
        print('Validation Accuracy: ' + str(epoch_val_accuracy))
        print('---------------------------------------------')

        # Save model and stats for checkpoints
        save_model(best_model, 'latest_model')
        epochs += 1
        save_stats(train_loss, val_loss, train_acc, val_acc, epochs,
                   scheduler.get_last_lr()[0], train_accs, val_accs)

    # Plot the loss curves of the whole run
    plot_loss(losses, val_losses)
    return train_loss, val_loss, train_acc, val_acc

def save_experiment(statistics):
    """
    Writes a one-row 'results.csv' summarizing the run into its results folder
    :param statistics: Sequence holding, in order, the minimum training loss,
        minimum validation loss, maximum training accuracy and maximum
        validation accuracy

    The run name and hyperparameters are taken from the module-level
    ``timestr`` and ``config``.
    """
    # (column name, value) pairs in the order the columns appear in the csv
    columns = [
        ('Model name', timestr),
        ('Learning rate', config.learning_rate),
        ('Weight decay', config.weight_decay),
        ('Batch size', config.batch_size),
        ('Epochs', config.epochs),
        ('Min Training Loss', statistics[0]),
        ('Min Validation Loss', statistics[1]),
        ('Maximum Training Accuracy', statistics[2]),
        ('Maximum Validation Accuracy', statistics[3]),
    ]

    summary_frame = pd.DataFrame({column: [value] for column, value in columns})

    csv_path = os.path.join('drive', 'MyDrive', 'GenImage', 'results',
                            timestr, 'results.csv')
    summary_frame.to_csv(csv_path, index=False, header=True)


# Entry point: build the model and dataloaders, optionally restore an earlier
# run, train, and write the run summary. The names assigned here (timestr,
# losses, val_losses, model, train_dataloader, val_dataloader) are module-level
# and are read by the helper functions above.
if __name__ == '__main__':

    results_root = os.path.join('drive', 'MyDrive', 'GenImage', 'results')

    # A resumed run keeps writing into the folder of the run it continues;
    # a fresh run gets a new timestamped folder.
    if config.load_checkpoint:
        if not config.file_checkpoint:
            raise ValueError('load_checkpoint is set but file_checkpoint is empty')
        timestr = config.file_checkpoint
    else:
        timestr = time.strftime("%Y%m%d-%H%M%S")

    checkpoint_path = os.path.join(results_root, timestr)
    print(f'All model checkpoints and training stats will be saved in {checkpoint_path}')


    losses = []
    val_losses = []
    train_accs, val_accs = [], []
    min_train_loss = None
    min_val_loss = None
    max_val_acc = None
    max_train_acc = None
    best_model = None
    epochs_ran = 0


    # Load processors and models
    model = CLIPModel()
    model.to(device)

    # Load datasets
    train_dset = ImageDataset()
    val_dset = ImageDataset(
        d_type = 'val'
    )

    # Create Dataloaders (evaluation does not depend on sample order, so the
    # validation loader is not shuffled)
    train_dataloader = DataLoader(train_dset, shuffle=True, batch_size=config.batch_size,
                                  num_workers=config.num_workers)
    val_dataloader = DataLoader(val_dset, shuffle=False, batch_size=config.batch_size,
                                num_workers=config.num_workers)

    # Load checkpoint if necessary:
    if config.load_checkpoint:

        print('Loading model from ' + config.file_checkpoint)

        # 'latest_model.pth' holds the state dict of the best weights; read it
        # once and use it both for the live model and as the best-so-far copy.
        checkpoint_state = torch.load(os.path.join(checkpoint_path, 'latest_model.pth'),
                                      map_location=device)
        model.load_state_dict(checkpoint_state)
        best_model = checkpoint_state

        with open(os.path.join(checkpoint_path, 'stats.json'), 'r') as f:
            stats = json.load(f)

        min_train_loss, min_val_loss = stats['min train loss'], stats['min val loss']
        losses, val_losses = stats['losses'], stats['val losses']
        epochs_ran = stats['epochs']
        max_train_acc, max_val_acc = stats['max train acc'], stats['max val acc']
        train_accs, val_accs = stats['training accuracies'], stats['val accuracies']

        print(f'Minimum Training Loss: {min_train_loss}')
        print(f'Training Losses: {losses}')
        print(f'Minimum Validation Loss: {min_val_loss}')
        print(f'Validation Losses: {val_losses}')
        print(f'Training Accuracies: {train_accs}')
        print(f'Validation Accuracies: {val_accs}')
        print(f'Maximum Training Accuracy: {max_train_acc}')
        print(f'Maximum Validation Accuracy: {max_val_acc}')
        print(f'Epochs ran: {epochs_ran}')

        # If loading a checkpoint, use the learning rate from the last epoch
        lr = stats['learning rate']
    else:
        lr = config.learning_rate

    # makedirs also creates missing parents (e.g. 'results') and is a no-op
    # for folders that already exist, which covers fresh and resumed runs.
    os.makedirs(os.path.join(checkpoint_path, 'Confusion Matrices'), exist_ok=True)

    min_train_loss, min_val_loss, max_train_acc, max_val_acc = (
        train(min_train_loss, min_val_loss, max_train_acc,
              max_val_acc, best_model, epochs_ran, lr, train_accs, val_accs))
    statistics = [min_train_loss, min_val_loss, max_train_acc, max_val_acc]
    save_experiment(statistics)


All model checkpoints and training stats will be saved in drive/MyDrive/GenImage/results/20240504-214639




-------------------- EPOCH 0 ---------------------


  7%|▋         | 35/527 [03:51<53:18,  6.50s/it]

In [None]:
# Give the Colab runtime back once every cell above has finished.
# NOTE(review): presumably this discards the VM's local files; results are
# written to the mounted Drive folder by the cells above -- confirm before relying on it.
from google.colab import runtime
runtime.unassign()