# Scattering transform

Compares the importance of different scales in a CNN with their importance under the Scattering Transform. Currently applied to the BDAPPV dataset.
* implement a scattering transform

In [None]:
# Libraries

import numpy as np 
import matplotlib.pyplot as plt
from PIL import Image
import os
import json
import torch
import torchvision
import torch.nn as nn
from kymatio.torch import Scattering2D
from kymatio.numpy import Scattering2D as npScattering2D
import tqdm
from src.bdappv import BDAPPVClassification
from src.utils import confusion
from matplotlib import gridspec
import matplotlib as mpl
import matplotlib.cm as cm
from sklearn.linear_model import LinearRegression

In [None]:
def get_K(L, J, num_channels=3):  # coefficient-count formula taken from kymatio
    """
    Return the total number of scattering channels for a multi-channel input.

    The per-channel count is ``1 + L*J + L**2 * J * (J - 1) / 2`` (the
    kymatio formula, summing the order-0, order-1 and order-2 terms). It is
    truncated to an int and multiplied by the number of input channels.

    L: number of angles of the scattering transform
    J: number of scales
    num_channels: number of input channels
    """
    order0 = 1
    order1 = L * J
    order2 = (L ** 2) * J * ((J - 1) / 2)
    per_channel = int(order0 + order1 + order2)
    return per_channel * num_channels

class Scattering2dCNN(nn.Module):
    '''
    Classifier applied on top of precomputed 2D scattering coefficients.

    The incoming coefficients are reshaped to ``(-1, K, out_shape, out_shape)``
    and batch-normalized. They are then optionally passed through a VGG-style
    stack of 3x3 convolutions and finally mapped to ``num_classes`` logits.

    Args:
        J: number of scales of the scattering transform. The spatial side of
            the coefficients is ``input_shape // 2**J``.
        input_shape: side length of the (square) input images.
        classifier_type: ``'linear'`` (a single linear layer), ``'mlp'`` (two
            hidden layers of 1024 units) or ``'cnn'`` (VGG-style conv stack).
        L: number of angles of the scattering transform.
        num_classes: number of output logits.

    Raises:
        ValueError: if ``classifier_type`` is not one of the values above.
    '''

    def __init__(self, J,
                 input_shape,
                 classifier_type='linear',
                 L=8,
                 num_classes=2):
        super().__init__()
        # Number of channels once the input channels and the scattering
        # coefficients are flattened together (get_K defaults to 3 input channels).
        self.in_channels = get_K(L, J)
        self.J = J
        self.input_shape = input_shape
        self.classifier_type = classifier_type
        self.num_classes = num_classes
        self.build()

    def build(self):
        """Create the batch-norm layer, the optional feature extractor and the classifier."""
        # VGG-style config: ints are conv output channels, 'M' is a 2x2 max-pool.
        cfg = [256, 256, 256, 'M', 512, 512, 512, 1024, 1024]
        self.K = self.in_channels
        # Spatial side of the scattering coefficients (downsampled by 2**J).
        self.out_shape = self.input_shape // 2 ** self.J
        flat_size = self.K * self.out_shape * self.out_shape

        self.bn = nn.BatchNorm2d(self.K)
        if self.classifier_type == 'cnn':
            layers = []
            # Track channels locally so calling build() again stays consistent.
            channels = self.K
            for v in cfg:
                if v == 'M':
                    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                else:
                    layers += [nn.Conv2d(channels, v, kernel_size=3, padding=1),
                               nn.BatchNorm2d(v),
                               nn.ReLU(inplace=True)]
                    channels = v
            # Pool to a fixed 2x2 map so the classifier size does not depend on J.
            layers.append(nn.AdaptiveAvgPool2d(2))
            self.features = nn.Sequential(*layers)
            self.classifier = nn.Linear(channels * 2 * 2, self.num_classes)

        elif self.classifier_type == 'mlp':
            self.classifier = nn.Sequential(
                nn.Linear(flat_size, 1024), nn.ReLU(),
                nn.Linear(1024, 1024), nn.ReLU(),
                nn.Linear(1024, self.num_classes))
            self.features = None

        elif self.classifier_type == 'linear':
            self.classifier = nn.Linear(flat_size, self.num_classes)
            self.features = None

        else:
            raise ValueError(
                "Unknown classifier_type {!r}; expected 'linear', 'mlp' or 'cnn'"
                .format(self.classifier_type))

    def forward(self, x):
        """
        Classify scattering coefficients.

        x: tensor holding ``K * out_shape**2`` values per sample. It is
            presumably the output of ``Scattering2D`` on 3-channel
            ``input_shape x input_shape`` images (TODO confirm against the
            caller). It is reshaped to ``(-1, K, out_shape, out_shape)``.

        Returns the logits, of shape ``(batch, num_classes)``.
        """
        x = self.bn(x.reshape(-1, self.K, self.out_shape, self.out_shape))
        if self.features is not None:
            x = self.features(x)
        x = x.reshape(x.size(0), -1)
        return self.classifier(x)


In [None]:
# Image lists; the "test" entry selects the evaluation images below.
with open("data/images_lists.json") as images_file:
    images_list = json.load(images_file)

dataset_dir = "path/to/dataset"  # placeholder: root holding the "google" and "ign" folders
models_dir = "weights-scattering"
batch_size = 512
# Use the GPU when available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Baseline transforms: no corruptions and no normalization, only tensor conversion.
BASELINE = torchvision.transforms.Compose([
    torchvision.transforms.ToPILImage(),
    torchvision.transforms.ToTensor(),
])

datasets = {
    # NOTE(review): only the Google split passes downsample=200; confirm this is intended.
    'google_test': BDAPPVClassification(os.path.join(dataset_dir, "google"), size=200,
                                        transform=BASELINE, images_list=images_list["test"],
                                        random=False, downsample=200),
    'ign_test': BDAPPVClassification(os.path.join(dataset_dir, "ign"), size=200,
                                     transform=BASELINE, images_list=images_list["test"],
                                     random=False),
}


## Quantitative evaluation 

We evaluate the trained models on the Google and IGN test sets.

### Baseline model evaluation

In [None]:
def eval(model, device, test_loader, scattering):
    """
    Evaluate a classifier on scattering features and return confusion counts and F1.

    Args:
        model: classifier applied to ``scattering(data)``; it is put in eval mode.
        device: device the batches are moved to.
        test_loader: iterable yielding ``(data, target, _)`` batches.
        scattering: callable mapping an image batch to scattering coefficients.

    Returns:
        ``(tp, tn, fp, fn, f1_score)`` accumulated over the whole loader.
        Precision, recall and F1 are reported as 0.0 when their denominator is
        zero (no predicted positives, no actual positives, or both metrics 0).
    """
    model.eval()
    tp, tn, fp, fn = 0, 0, 0, 0

    with torch.no_grad():
        for data, target, _ in tqdm.tqdm(test_loader):
            data, target = data.to(device), target.to(device)
            output = model(scattering(data))
            # Predicted class = index of the largest logit.
            pred = output.argmax(dim=1)

            # confusion() is unpacked as (tp, fp, tn, fn, <unused>) per batch.
            true_positives, false_positives, \
                true_negatives, false_negatives, _ = confusion(pred, target)

            tp += true_positives
            tn += true_negatives
            fp += false_positives
            fn += false_negatives

    predicted_positives = tp + fp
    actual_positives = tp + fn
    precision = tp / predicted_positives if predicted_positives > 0 else 0.0
    recall = tp / actual_positives if actual_positives > 0 else 0.0

    if precision + recall > 0:
        f1_score = 2 * (precision * recall) / (precision + recall)
    else:
        f1_score = 0.0

    return tp, tn, fp, fn, f1_score

In [None]:
# Results indexed by scattering scale J, then by evaluation set name.
eval_res = {  # each key corresponds to a scale (J=1,2,3)
    1: {},
    2: {},
    3: {},
}
input_shape = 200

# Display name -> key in `datasets` (explicit rather than relying on dict order).
eval_sets = {'Google': 'google_test', 'IGN': 'ign_test'}

for model_file in os.listdir(models_dir):
    # Only models whose name contains 'ign' are evaluated; the scale J is the
    # third '_'-separated field of the file name.
    if 'ign' not in model_file:
        continue

    J = int(model_file.split('_')[2])
    scattering = Scattering2D(J=J, shape=(input_shape, input_shape)).to(device)

    # NOTE(review): assumes each entry of models_dir is a saved (pickled) model file.
    model = torch.load(os.path.join(models_dir, model_file))
    model.eval()
    model.to(device)

    for case, key in eval_sets.items():
        print('Evaluating case ...... {}'.format(case))
        loader = torch.utils.data.DataLoader(datasets[key], batch_size=batch_size)

        tp, tn, fp, fn, f1_score = eval(model, device, loader, scattering)

        # The confusion matrix is stored as (tp, fp, tn, fn); later cells unpack it in that order.
        eval_res.setdefault(J, {})[case] = {
            'confusion_matrix': (tp, fp, tn, fn),
            'f1_score': f1_score,
        }

In [None]:
# Emit LaTeX-style table rows: one header, then one row per (scale, dataset).
header = "&&F1 Score & TP & TN & FP & FN"
print(header)

row_template = '&{}&{:0.2f}&{}&{}&{}&{}'
for J, results_by_case in eval_res.items():
    for case, res in results_by_case.items():
        # Stored order is (tp, fp, tn, fn); the printed order is tp, tn, fp, fn.
        tp, fp, tn, fn = res['confusion_matrix']
        print(row_template.format(case, res['f1_score'], tp, tn, fp, fn))

In [None]:
def random_classifier_performance(positives, total, p_positive=0.5):
    """
    Return the expected performance of a random classifier as a baseline.

    The classifier labels each sample positive with probability ``p_positive``,
    independently of its true class. The expected confusion counts are
    therefore TP = positives * p, FN = positives * (1 - p),
    FP = negatives * p and TN = negatives * (1 - p). With the default p = 0.5
    (a coin flip), precision equals the positive rate and recall is 0.5.

    Args:
        positives: number of positive samples.
        total: total number of samples.
        p_positive: probability that the random classifier predicts positive.

    Returns:
        ``(precision, recall, f1)``. A metric whose denominator is zero is
        reported as 0.0.

    Raises:
        ValueError: if ``total`` is not positive, ``positives`` is outside
            ``[0, total]``, or ``p_positive`` is outside ``[0, 1]``.
    """
    if total <= 0:
        raise ValueError("total must be positive, got {}".format(total))
    if positives < 0 or positives > total:
        raise ValueError("positives must lie in [0, total], got {}".format(positives))
    if not 0.0 <= p_positive <= 1.0:
        raise ValueError("p_positive must lie in [0, 1], got {}".format(p_positive))

    negatives = total - positives

    # Expected confusion matrix of the random predictor.
    tp = positives * p_positive
    fn = positives * (1 - p_positive)
    fp = negatives * p_positive
    tn = negatives * (1 - p_positive)

    print(tp, tn, fp, fn)

    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0

    if precision + recall > 0:
        f1 = 2 * (precision * recall) / (precision + recall)
    else:
        f1 = 0.0
    return precision, recall, f1

# Random-classifier baseline, computed from the class balance seen at scale J=1.
J = 1
summary_template = '{} : Precision: {:0.2f}, Recall: {:0.2f}, F1 : {:0.2f} '

for case, res in eval_res[J].items():
    tp, fp, tn, fn = res['confusion_matrix']
    positives = tp + fn
    total = tp + tn + fp + fn

    precision, recall, f1 = random_classifier_performance(positives, total)

    line = summary_template.format(case, precision, recall, f1)
    print(line)