# The StereoScope
built for Hugging Face's AI or Not competition.<br>
https://huggingface.co/spaces/competitions/aiornot <br>

Copyright 2023, [Jozsef Szalma](https://www.linkedin.com/in/szalma/)<br>
Creative Commons Attribution-NonCommercial 4.0 International Public License <br>
https://creativecommons.org/licenses/by-nc/4.0/legalcode <br>
If you want to use the below code commercially then do contact me for licensing or with other offers of collaboration ;)<br>
Also keep in mind, the weights of pretrained ConvNeXt-V2 are also on CC BY-NC 4.0 as of writing.

Hardware requirements: I built this notebook on an NVIDIA GPU with 24 GB of memory.

In [None]:
#important! The below code is relying on timm models not available in timm 0.6.x, I used timm 0.8.12.dev0 installed from git
%pip install git+https://github.com/rwightman/pytorch-image-models.git

In [None]:
import os
import numpy as np
import random
from random import randint
from tqdm import tqdm
from tqdm.auto import tqdm
import pandas as pd
import math

import torch
import timm
import torchmetrics
import torch.optim as optim
import torch.nn as nn

from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import autocast 
from torch.cuda.amp import GradScaler

from timm.models.layers import trunc_normal_

from huggingface_hub import login
from datasets import load_dataset

import mlflow

### Parameters<br>
Credentials and working directories are read from the environment variables HF_KEY, WORKING_DIR_WIN and CHECKPOINT_DIR_WIN.

In [2]:
# ConvNeXt
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the MIT license

# ConvNeXt-V2
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree (Attribution-NonCommercial 4.0 International (CC BY-NC 4.0))
# No code was used directly from ConvNeXt-V2, however the weights are CC BY-NC 4.0 so beware if using commercially.


# Central configuration for the notebook. The whole dictionary is also logged to MLflow
# (mlflow.log_params(args)) right before training starts.
args = {
        'random_seed'           : 42,  # seeds torch, numpy, random and the split generator below
        'rgb_target_size'       : (384,384),  # (width, height) of the crop taken by simple_crop_tensor
        'fft_target_size'       : (384,384),  # not referenced by the visible code; the spectrogram is computed from the RGB crop
        'rgb_mean'              : (0.485, 0.456, 0.406),  # per-channel mean passed to transforms.Normalize
        'rgb_std'               : (0.229, 0.224, 0.225),  # per-channel std passed to transforms.Normalize
        'rgb_model_name'        : 'convnextv2_huge.fcmae_ft_in22k_in1k_384',  # timm model name of the RGB stream
        'fft_model_name'        : 'convnext_xlarge.fb_in22k_ft_in1k_384',  # timm model name of the spectrogram stream
        'rgb_model_pretrained'  : True,
        'fft_model_pretrained'  : True, 
        'rgb_model_frozen'      : False,  # True sets requires_grad=False on every backbone parameter
        'fft_model_frozen'      : False, 

        'batch_size'            : 2, 
        'num_workers'           : 0, #multithreading is a tad problematic in Jupyter on Windows
        'num_epochs'            : 100,  # also used as T_max of the cosine learning-rate schedule
        'weight_decay'          : 1e-3,  # passed to AdamW
        'resume_from_checkpoint': None, #'DualDetector_23a34a28b8024139bf238ce289204654_epoch=29_val_loss=1e-05.pth', 
        'validation_size'       : 0.1,  # share of the train split held out by train_test_split

        'model_default_lr'      : 4e-3 * 0.8,
        'model_default_batch'   : 4096,
        #LR is scaled from the model_default_lr with (batch_size / default_batch_size)

        'gradient_clip'         : None,  # max_norm for clip_grad_norm_; None disables clipping
        'gradient_accum'        : 4  # the optimizer steps once every this many batches

}

# Secrets and locations come from the environment; os.getenv returns None for unset variables.
hf_key = os.getenv("HF_KEY")  # token passed to huggingface_hub.login
working_dir = os.getenv("WORKING_DIR_WIN")  # used as HF_HOME, as the working directory and as the MLflow experiment name
checkpoint_dir = os.getenv("CHECKPOINT_DIR_WIN")  # directory prefix for checkpoint files

# Seed every random number generator the notebook uses.
torch.manual_seed(args['random_seed'])
np.random.seed(args['random_seed'])
random.seed(args['random_seed'])
np_generator = np.random.default_rng(args['random_seed'])  # handed to train_test_split for the validation split

# The model and all preprocessed tensors are moved to this device.
torch_device = 'cuda'

### Multi-Stream architecture
- the model is composed of two pretrained networks whose final fully connected layers are removed; their embeddings are concatenated and fed to a single new fully connected output layer
- the RGB network is exposed to random 384x384 crops 
- the FFT network is exposed to 2D spectrograms of the same cropped RGB image

In [None]:
class DualDetector(nn.Module):
    """Two-stream binary classifier built from two timm backbones.

    One backbone receives the RGB crop, the other the spectrogram of that crop.
    The classification layers of both backbones are replaced with identities,
    the resulting embeddings are concatenated, and a single linear layer maps
    them to one output value per sample.

    Args:
        rgb_model_name: timm model name of the RGB backbone.
        fft_model_name: timm model name of the spectrogram backbone.
        rgb_model_pretrained: forwarded to ``timm.create_model(pretrained=...)``.
        fft_model_pretrained: forwarded to ``timm.create_model(pretrained=...)``.
        rgb_model_frozen: if true, every RGB backbone parameter gets ``requires_grad = False``.
        fft_model_frozen: if true, every spectrogram backbone parameter gets ``requires_grad = False``.

    The module also owns the loss (``criterion``) and the validation metric
    (``val_accuracy``) so callers can reach them through the model object.
    """

    def __init__(self, rgb_model_name='convnext_xlarge_in22k', fft_model_name='convnext_xlarge_in22k',
                 rgb_model_pretrained=True, fft_model_pretrained=True,
                 rgb_model_frozen=True, fft_model_frozen=True):
        super().__init__()

        self.model_rgb = timm.create_model(rgb_model_name, pretrained=rgb_model_pretrained)
        self.model_fft = timm.create_model(fft_model_name, pretrained=fft_model_pretrained)

        # Freezing only touches the backbones; the joint head created below stays trainable.
        for backbone, frozen in ((self.model_rgb, rgb_model_frozen), (self.model_fft, fft_model_frozen)):
            if frozen:
                for weight in backbone.parameters():
                    weight.requires_grad = False

        # Read the embedding widths before the original classifiers are thrown away.
        rgb_width = self.model_rgb.head.fc.in_features
        fft_width = self.model_fft.head.fc.in_features

        self.model_rgb.head.fc = nn.Identity()
        self.model_fft.head.fc = nn.Identity()

        self.head_fc1 = nn.Linear(rgb_width + fft_width, 1)

        self.criterion = torch.nn.BCEWithLogitsLoss()
        self.val_accuracy = torchmetrics.Accuracy('binary')

    def forward(self, rgb, fft):
        """Return the head output for a batch of RGB crops and their spectrograms.

        Both inputs are embedded by their own backbone, the embeddings are
        concatenated along dimension 1 and passed through ``head_fc1``.
        """
        rgb_embedding = self.model_rgb(rgb)
        fft_embedding = self.model_fft(fft)
        joint_embedding = torch.cat([rgb_embedding, fft_embedding], dim=1)
        return self.head_fc1(joint_embedding)

        

In [4]:

# Build the two-stream model; every constructor argument is taken from the
# identically named entry of `args`.
dualdetector = DualDetector(**{
    option: args[option]
    for option in (
        'rgb_model_name',
        'fft_model_name',
        'rgb_model_pretrained',
        'fft_model_pretrained',
        'rgb_model_frozen',
        'fft_model_frozen',
    )
})

### dataset from Hugging Face<br>

In [None]:
# Authenticate against the Hugging Face Hub with the token read from HF_KEY.
login(token=hf_key,add_to_git_credential=True)

# NOTE(review): HF_HOME is set after `huggingface_hub` and `datasets` were imported; those
# libraries presumably resolve their cache locations at import time, so this assignment may
# not affect where the dataset is cached — confirm.
os.environ['HF_HOME']=working_dir
os.chdir(working_dir)

ds = load_dataset('competitions/aiornot')

# Hold out args['validation_size'] of the training data for validation, using the seeded
# numpy generator, and store both parts back on the dataset dictionary.
split = ds["train"].train_test_split(args['validation_size'],generator=np_generator)
ds["train"] = split["train"]
ds["validation"] = split["test"]

### Preprocessing
I feed the network two images; a cropped RGB image and its 2D spectrogram.<br>
I intentionally avoid downsampling images.<br>
Normalizing is done with the mean and std specific to the pretrained network.<br>
Data augmentation is a trade-off between disrupting the fake image signals (repeating patterns, color-shifts, changes in sharpness) and having a large enough training set.<br>
At inference time the images are center cropped.

In [6]:
# Deterministic pipeline: normalisation only.
rgb_transform_val = transforms.Compose([
    transforms.Normalize(mean=args['rgb_mean'], std=args['rgb_std']),
])

# Training pipeline: random flips and colour perturbations followed by the same normalisation.
rgb_transform = transforms.Compose([
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    transforms.RandomGrayscale(p=0.1),
    transforms.Normalize(mean=args['rgb_mean'], std=args['rgb_std']),
])


def simple_crop_tensor(target_size, tensor_img, how='random'):
    """Crop a (C, H, W) tensor to ``target_size`` without padding or rescaling.

    Args:
        target_size: ``(width, height)`` of the requested crop.
        tensor_img: tensor indexed as ``[channel, row, column]``.
        how: ``'random'`` picks a uniformly random crop position,
            ``'center'`` takes the centred crop.

    Returns:
        The cropped view of ``tensor_img``. An axis that is not larger than the
        target is left untouched; if neither axis is larger, ``tensor_img``
        itself is returned unchanged.

    Raises:
        NotImplementedError: if cropping is needed and ``how`` is not a supported mode.
    """
    height, width = tensor_img.shape[1], tensor_img.shape[2]
    target_width, target_height = target_size

    # Surplus pixels per axis; an axis that is already small enough has no room to move.
    width_range = max(width - target_width, 0)
    height_range = max(height - target_height, 0)

    if width_range == 0 and height_range == 0:
        return tensor_img

    if how == 'random':
        # randint is inclusive on both ends, so every valid offset 0..range can be drawn.
        x_movement = randint(0, width_range)
        y_movement = randint(0, height_range)
    elif how == 'center':
        x_movement = width_range // 2
        y_movement = height_range // 2
    else:
        raise NotImplementedError(f"unsupported crop mode: {how!r}")

    return tensor_img[:, y_movement:y_movement + target_height, x_movement:x_movement + target_width]


def tensor_spectrogram(im):
    """Replace the first three channels of ``im`` with their 2D log-magnitude spectra.

    For each channel the magnitude of the 2D FFT is taken (phase is discarded),
    log-compressed, rescaled so that its 5th/95th percentiles map to -1/+1 and
    clamped to [-1, 1]. The result is written back into ``im`` in place and
    ``im`` is returned, so pass a copy if the original values are still needed.
    """
    for channel in (0, 1, 2):
        magnitude = torch.fft.fft2(im[channel, :, :]).abs()
        log_spectrum = torch.log(magnitude + 1e-3)

        # The percentiles are computed with numpy, which needs a host copy of the spectrum.
        spectrum_np = log_spectrum.cpu().numpy()
        low = np.percentile(spectrum_np, 5) + 1e-8
        high = np.percentile(spectrum_np, 95) + 1e-7

        # Map [low, high] to [0, 1], then to [-1, 1]; values outside are clipped.
        scaled = (log_spectrum - low) / (high - low)
        scaled = (scaled - 0.5) * 2
        im[channel, :, :] = scaled.clamp(min=-1, max=1)

    return im


The preprocessing code is a bit of spaghetti and could certainly be improved.<br>
<br>
I'm off-loading preprocessing to the GPU at the earliest opportunity, as I run into a CPU bottleneck here.

In [7]:

def train_transform_fn(example_batch):
    """Turn a batch of training examples into model inputs.

    Every image in ``example_batch['image']`` is converted to a float tensor in
    [0, 1] on ``torch_device``, augmented and normalised with ``rgb_transform``
    and randomly cropped to ``args['rgb_target_size']``. The spectrogram is
    computed from a copy of that crop.

    Returns:
        dict with the stacked crops (``'x_rgb'``), the stacked spectrograms
        (``'x_fft'``) and the untouched labels (``'y'``).
    """
    rgb_crops = []
    spectrograms = []

    for pil_image in example_batch['image']:
        pixels = transforms.functional.pil_to_tensor(pil_image).to(torch_device).float() / 255.0
        augmented = rgb_transform(pixels)
        #TODO move this into a transforms.Lambda
        crop = simple_crop_tensor(target_size=args['rgb_target_size'], tensor_img=augmented, how='random')

        rgb_crops.append(crop)
        # tensor_spectrogram works in place, hence the copy.
        spectrograms.append(tensor_spectrogram(im=crop.detach().clone()))

    return {
        'x_rgb': torch.stack(rgb_crops),
        'x_fft': torch.stack(spectrograms),
        'y': example_batch['label'],
    }
    

def val_transform_fn(example_batch):
    """Turn a batch of validation examples into model inputs.

    Every image in ``example_batch['image']`` is converted to a float tensor in
    [0, 1] on ``torch_device``, normalised with the deterministic
    ``rgb_transform_val`` pipeline (no random augmentation, so validation
    metrics are comparable between epochs) and centre-cropped to
    ``args['rgb_target_size']``. The spectrogram is computed from a copy of
    that crop.

    Returns:
        dict with the stacked crops (``'x_rgb'``), the stacked spectrograms
        (``'x_fft'``) and the untouched labels (``'y'``).
    """
    tensor_lst_rgb = []
    tensor_lst_fft = []
    for im in example_batch['image']:

        im_tensor = transforms.functional.pil_to_tensor(im).to(torch_device).float() / 255.0

        # Evaluation must not use the random flips / colour jitter of the training pipeline.
        im_tensor = rgb_transform_val(im_tensor)

        im_tensor_rgb = simple_crop_tensor(target_size=args['rgb_target_size'], tensor_img=im_tensor, how='center')
        # tensor_spectrogram works in place, hence the copy.
        im_tensor_fft = tensor_spectrogram(im=im_tensor_rgb.detach().clone())

        tensor_lst_fft.append(im_tensor_fft)
        tensor_lst_rgb.append(im_tensor_rgb)
    x_rgb = torch.stack(tensor_lst_rgb)
    x_fft = torch.stack(tensor_lst_fft)

    y = example_batch['label']
    return {'x_rgb': x_rgb,'x_fft': x_fft, 'y': y}
    


def test_transform_fn(example_batch):
    """Turn a batch of unlabeled test examples into model inputs.

    Every image in ``example_batch['image']`` is converted to a float tensor in
    [0, 1] on ``torch_device``, normalised with the deterministic
    ``rgb_transform_val`` pipeline (no random augmentation, so predictions are
    reproducible) and centre-cropped to ``args['rgb_target_size']``. The
    spectrogram is computed from a copy of that crop.

    Returns:
        dict with the stacked crops (``'x_rgb'``) and the stacked spectrograms
        (``'x_fft'``).
    """
    tensor_lst_rgb = []
    tensor_lst_fft = []
    for im in example_batch['image']:

        im_tensor = transforms.functional.pil_to_tensor(im).to(torch_device).float() / 255.0

        # Inference must not use the random flips / colour jitter of the training pipeline.
        im_tensor = rgb_transform_val(im_tensor)

        im_tensor_rgb = simple_crop_tensor(target_size=args['rgb_target_size'], tensor_img=im_tensor, how='center')
        # tensor_spectrogram works in place, hence the copy.
        im_tensor_fft = tensor_spectrogram(im=im_tensor_rgb.detach().clone())

        tensor_lst_fft.append(im_tensor_fft)
        tensor_lst_rgb.append(im_tensor_rgb)
    x_rgb = torch.stack(tensor_lst_rgb)
    x_fft = torch.stack(tensor_lst_fft)

    return {'x_rgb': x_rgb,'x_fft': x_fft}
    


def collate_fn(batch):
    """Collate labeled examples into ``(x_rgb, x_fft, y)``.

    The per-example ``'x_rgb'`` and ``'x_fft'`` tensors are stacked along a new
    leading dimension; the ``'y'`` labels become a float tensor on ``torch_device``.
    """
    stacked = {
        key: torch.stack([sample[key] for sample in batch])
        for key in ('x_rgb', 'x_fft')
    }
    targets = torch.tensor([sample['y'] for sample in batch]).float()
    return stacked['x_rgb'], stacked['x_fft'], targets.to(torch_device)
    

def collate_fn_test(batch):
    """Collate unlabeled examples into ``(x_rgb, x_fft)``.

    The per-example ``'x_rgb'`` and ``'x_fft'`` tensors are each stacked along a
    new leading dimension.
    """
    rgb_batch, fft_batch = (
        torch.stack([sample[key] for sample in batch])
        for key in ('x_rgb', 'x_fft')
    )
    return rgb_batch, fft_batch

### MLFlow 
I'm an avid practitioner of "lossing" and I can't always get Tensorboard to show me data during training and that makes me sad.<br>
https://twitter.com/JozsefSzalma/status/1621060399956213761?s=20&t=3xoxeOdMjZTM1kdGzvLADw

In [8]:
# The experiment is named after the working directory. set_experiment activates the
# experiment and creates it first when it does not exist yet, so no separate
# create/lookup (and no exception swallowing) is needed.
experiment = mlflow.set_experiment(working_dir)
experiment_id = experiment.experiment_id

mlflow.pytorch.autolog()

# Open the run now; its id is embedded in the checkpoint file names.
mlflow.start_run()
run = mlflow.active_run()
run_id = run.info.run_id

### Training Loop<br>

a couple of design choices:
- data loaders pin_memory=False: the data is already on the GPU so this won't work otherwise
- the training loop is a bit artisanal on purpose; my apologies to all the Pytorch Lightning fans.
- mixed precision to fit the architecture into the GPU and to speed up the training (GradScaler and autocast)
- a bit of gradient accumulation as I don't have much trust in batch size of 2 

In [None]:

# The reference learning rate belongs to the reference batch size, so it is scaled
# linearly down to the (much smaller) batch size used here.
learning_rate = args['model_default_lr'] * (args['batch_size'] / args['model_default_batch'])
args['learning_rate'] = learning_rate

# pin_memory stays False: the transform functions already return tensors on the GPU.
train_loader = DataLoader(ds['train'].with_transform(train_transform_fn), batch_size=args['batch_size'], shuffle=True, collate_fn=collate_fn, num_workers=args['num_workers'], pin_memory=False, drop_last=True)
val_loader = DataLoader(ds['validation'].with_transform(val_transform_fn), batch_size=args['batch_size'], shuffle=False, collate_fn=collate_fn, num_workers=args['num_workers'], pin_memory=False)

optimizer = optim.AdamW(dualdetector.parameters(), lr=learning_rate, weight_decay=args['weight_decay'])
scheduler = CosineAnnealingLR(optimizer, T_max=args['num_epochs'], eta_min=0, last_epoch=-1)

# Loss scaling for float16 mixed precision.
scaler = GradScaler()

start_epoch = 0

if args['resume_from_checkpoint'] is not None:
    # Load the checkpoint to CPU first and move only the model to the GPU once its state dictionary is restored.
    checkpoint = torch.load(os.path.join(checkpoint_dir, args['resume_from_checkpoint']), map_location='cpu')
    dualdetector.load_state_dict(checkpoint['model_state_dict'])
    dualdetector.to(torch_device)
    # The model is on the GPU before the optimizer state is restored so the state ends up on the same device.
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    # Checkpoints written before the scaler state was saved do not contain this key.
    if 'scaler_state_dict' in checkpoint:
        scaler.load_state_dict(checkpoint['scaler_state_dict'])
    torch.set_rng_state(checkpoint['torch_rng_state'])
    start_epoch = checkpoint['epoch'] + 1
    print('resuming from epoch ', start_epoch)
else:
    dualdetector.to(torch_device)

args['start_epoch'] = start_epoch
torch.set_float32_matmul_precision('medium') #RTX3090 specific optimization; 

mlflow.log_params(args)

num_train_batches = len(train_loader)

for epoch in range(start_epoch, args['num_epochs']):
    running_loss = 0.0
    dualdetector.train()
    with torch.set_grad_enabled(True):
        with tqdm(total=num_train_batches) as pbar:
            for i, (x_rgb, x_fft, labels) in enumerate(train_loader):

                # autocast + GradScaler give float16 mixed precision
                with autocast(device_type=torch_device, dtype=torch.float16):
                    outputs = dualdetector(x_rgb, x_fft).squeeze(1)
                    loss = dualdetector.criterion(outputs, labels)

                # Divide by the accumulation length so the accumulated gradient is the mean,
                # not the sum, of the per-batch gradients (keeps gradient_clip independent of
                # gradient_accum). `loss` itself stays unscaled for reporting.
                scaler.scale(loss / args['gradient_accum']).backward()

                # Step every `gradient_accum` batches, and also on the last batch so that a
                # trailing partial group is applied instead of leaking into the next epoch.
                is_last_batch = (i + 1) == num_train_batches
                if (i + 1) % args['gradient_accum'] == 0 or is_last_batch:

                    if args['gradient_clip'] is not None:
                        # Gradients must be unscaled before their norm is measured.
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(dualdetector.parameters(), max_norm=args['gradient_clip'])

                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()

                running_loss += loss.item()
                pbar.set_description(f"running_loss:{(running_loss/(i+1)):.6f}")
                pbar.update(1)

    # Remember the learning rate that was used during this epoch before the scheduler changes it.
    current_lr = optimizer.param_groups[0]['lr']
    scheduler.step()

    # eval step
    dualdetector.eval()
    running_val_loss = 0.0
    with torch.no_grad():
        for x_rgb, x_fft, labels in tqdm(val_loader):
            with autocast(device_type=torch_device, dtype=torch.float16):
                outputs = dualdetector(x_rgb, x_fft).squeeze(1)
                val_loss = dualdetector.criterion(outputs, labels)
            # .item() keeps the running sum a plain float instead of a GPU tensor.
            running_val_loss += val_loss.item()
            dualdetector.val_accuracy.update(outputs.sigmoid(), labels)

    # logging
    running_loss = running_loss / num_train_batches
    running_val_loss = running_val_loss / len(val_loader)
    val_accuracy = dualdetector.val_accuracy.compute().item()
    # Without the reset the metric would keep averaging over all previous epochs.
    dualdetector.val_accuracy.reset()
    print(f"Epoch: {epoch}/{args['num_epochs']}, Loss: {running_loss}, Validation Loss: {running_val_loss}, Validation Accuracy: {val_accuracy}, Learning Rate: {current_lr} ")

    mlflow.log_metric('lr', current_lr, step=epoch)
    mlflow.log_metric('loss', running_loss, step=epoch)
    mlflow.log_metric('val_loss', running_val_loss, step=epoch)
    mlflow.log_metric('val_accuracy', val_accuracy, step=epoch)

    # checkpointing (one file per epoch)
    checkpoint_path = os.path.join(checkpoint_dir, f"DualDetector_{run_id}_epoch={epoch}_val_loss={round(float(running_val_loss),5)}.pth")
    torch.save({
                'epoch': epoch,
                'model_state_dict': dualdetector.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'torch_rng_state': torch.get_rng_state(),
                # Loss of the last training batch, stored as a float rather than a GPU tensor.
                'loss': loss.item(),
                'batch_size': args['batch_size'],
                'learning_rate': learning_rate,
                'current_lr': current_lr,
                'weight_decay': args['weight_decay']
                },
                checkpoint_path)

    


In [None]:
# Close the MLflow run that was opened with mlflow.start_run() before training.
mlflow.end_run()

### Inference

In [8]:

# File name (inside checkpoint_dir) of the checkpoint used for the submission.
checkpoint_name = 'DualDetector_23a34a28b8024139bf238ce289204654_epoch=29_val_loss=1e-05.pth'

# Every parameter is overwritten by the checkpoint below (load_state_dict is strict by
# default), so the pretrained weights are not requested again.
dualdetector = DualDetector(
                            rgb_model_name=args['rgb_model_name'],
                            fft_model_name=args['fft_model_name'],
                            rgb_model_pretrained=False,
                            fft_model_pretrained=False,
                            rgb_model_frozen=args['rgb_model_frozen'],
                            fft_model_frozen=args['fft_model_frozen']
                            )

# Load to CPU first: the checkpoint also holds the optimizer state, which should not be
# placed on the GPU next to the model. Only the model itself is moved afterwards.
checkpoint = torch.load(os.path.join(checkpoint_dir, checkpoint_name), map_location='cpu')
dualdetector.load_state_dict(checkpoint['model_state_dict'])

dualdetector.to(torch_device)
dualdetector.eval()
torch.set_float32_matmul_precision('medium') #RTX3090 specific optimization; 

test_loader = DataLoader(ds['test'].with_transform(test_transform_fn), batch_size=16, shuffle=False, collate_fn=collate_fn_test, num_workers=0, pin_memory=False)

In [None]:

# Sigmoid outputs of the model for every test image, in dataset order.
all_preds_sigm = []

with torch.no_grad(), autocast(device_type=torch_device, dtype=torch.float16):
    for x_rgb, x_fft in tqdm(test_loader):
        x_rgb = x_rgb.to(torch_device)
        x_fft = x_fft.to(torch_device)
        out = dualdetector(x_rgb, x_fft)
        # squeeze(1) removes only the output dimension; a bare squeeze() would also drop the
        # batch dimension of a final batch holding a single image and break extend().
        all_preds_sigm.extend(out.squeeze(1).sigmoid().detach().cpu().numpy().tolist())

# Keep every test column except the image and store the prediction in the 'label' column.
df_sigm = pd.DataFrame(ds['test'].remove_columns(['image']))

df_sigm['label'] = all_preds_sigm


In [None]:
# Preview the first rows of the prediction table.
df_sigm.head()

In [11]:
# Write the predictions to "<checkpoint_name>.csv" without the index column; the relative
# path resolves against the current working directory (changed to working_dir earlier).
df_sigm.to_csv(checkpoint_name + '.csv', index=None)
