In [None]:
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import seaborn as sns
import torch.nn.functional as F
from torchsummary import summary
from glob import glob
from PIL import Image

In [None]:
# Choose which GPU to use. The device is only selected when CUDA is actually
# available; calling torch.cuda.set_device first would crash on CPU-only hosts.
gpu_no = 0
use_gpu = torch.cuda.is_available()
if use_gpu:
    torch.cuda.set_device(gpu_no)

In [None]:
# Clean the dataset (optional): delete files that PIL cannot open at all
# (Image.open raises OSError, e.g. for unidentifiable files).
# TODO: absolute, machine-specific path — consider deriving it from data_dir.
CLEAN_GLOB = "/home/jovyan/EJ/causality/chexnet/database/images/*png"
for img in glob(CLEAN_GLOB):
    try:
        # Image.open is lazy and keeps the file open; the context manager
        # closes the handle so thousands of files do not leak descriptors.
        with Image.open(img):
            pass
    except OSError:
        os.remove(img)

In [None]:
# Default database path, read as a torchvision ImageFolder (each sub-directory
# is treated as one class; the labels are not used by the training loop).
data_dir = 'database/'

# ImageNet-style preprocessing: resize (256), centre-crop to 224x224,
# convert to a tensor and normalise with the ImageNet channel mean/std.
trans = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
image_datasets = datasets.ImageFolder(root=data_dir, transform=trans)
dataloders = torch.utils.data.DataLoader(dataset=image_datasets, batch_size=8, shuffle=True)

# Number of images in the dataset (used later to normalise the epoch loss).
dataset_sizes = len(image_datasets)

# LOAD MODEL

In [None]:
import warnings
# NOTE(review): this silences *all* warnings for the rest of the session,
# including deprecation warnings from torch/torchvision.
warnings.filterwarnings('ignore')

# DenseNet-121 backbone with a 14-way classifier head. No pretrained weights
# here: they are loaded from models/m-25012018-123527.pth.tar in a later cell.
densenet121 = models.densenet121(pretrained=False)
kernelCount = densenet121.classifier.in_features
# Explicit dim=1: softmax over the 14 classes of the (batch, 14) logits.
# Omitting dim relies on deprecated implicit-dimension inference (same result
# for 2D input, but it emits a warning that the filter above hides).
densenet121.classifier = nn.Sequential(nn.Linear(kernelCount, 14), nn.Softmax(dim=1))
if use_gpu:
    densenet121 = densenet121.cuda()
densenet121 = densenet121.float()

In [None]:
from collections import OrderedDict

# Key prefix stripped from the checkpoint's parameter names. The original code
# removed a fixed 19 characters, which is exactly len("module.densenet121.")
# (presumably DataParallel's "module." plus a "densenet121." wrapper attribute
# — TODO confirm against the checkpoint). Keys without the prefix are kept
# intact so load_state_dict's strict check reports them instead of loading
# silently mangled names.
CHEXNET_PREFIX = "module.densenet121."
# Older torchvision DenseNet layer names -> current names, applied in order.
_LEGACY_LAYER_NAMES = (
    ("norm.1", "norm1"),
    ("norm.2", "norm2"),
    ("conv.1", "conv1"),
    ("conv.2", "conv2"),
)

# map_location lets a GPU-saved checkpoint be restored on a CPU-only machine.
checkpoint = torch.load("models/m-25012018-123527.pth.tar",
                        map_location=None if use_gpu else "cpu")
new_state_dict = OrderedDict()
for k, v in checkpoint["state_dict"].items():
    name = k[len(CHEXNET_PREFIX):] if k.startswith(CHEXNET_PREFIX) else k
    for old, new in _LEGACY_LAYER_NAMES:
        name = name.replace(old, new)
    new_state_dict[name] = v
densenet121.load_state_dict(new_state_dict)


In [None]:
# Make sure the restored network sits on the right device/dtype, then print a
# layer-by-layer summary for a single 3x224x224 input. torchsummary defaults
# to device="cuda", so pass the device explicitly to also work without a GPU.
if use_gpu:
    densenet121 = densenet121.cuda()
densenet121 = densenet121.float()
summary(densenet121, (3, 224, 224), device="cuda" if use_gpu else "cpu")

In [None]:
import copy

# Split DenseNet-121 after child `freeze_layer` of `features`:
#   f_model    : features[:freeze_layer]  (left trainable)
#   r_model    : features[freeze_layer:]
#   cls_model  : the classifier (Linear + Softmax)
#   r_freeze / cls_freeze : frozen copies of r_model / cls_model.
# The frozen branch must be a deep copy. nn.Sequential(*modules) only re-wraps
# the same module objects, so freezing an un-copied branch would also freeze
# r_model / cls_model and make both branches share BatchNorm running stats.
freeze_layer = 6
feature_layers = list(densenet121.features.children())
f_model = nn.Sequential(*feature_layers[:freeze_layer])
r_model = nn.Sequential(*feature_layers[freeze_layer:])
cls_model = nn.Sequential(*list(densenet121.classifier.children()))
r_freeze = copy.deepcopy(r_model)
cls_freeze = copy.deepcopy(cls_model)

for param in r_freeze.parameters():
    param.requires_grad = False
for param in cls_freeze.parameters():
    param.requires_grad = False
print(f_model)
print(cls_model)
# densenet121

In [None]:
# Convolutional auto-encoder layers (ae1 expects 128 input channels):
# encoder 128 -> 32 -> 16 (bottleneck), decoder 16 -> 32 -> 128.
# All are 3x3 convolutions with padding 1 and stride 1, so height/width are preserved.
ae1 = nn.Conv2d(128, 32, 3, padding=1)
hd1 = nn.Conv2d(32, 16, 3, padding=1)
hd2 = nn.Conv2d(16, 32, 3, padding=1)
ae2 = nn.Conv2d(32, 128, 3, padding=1)

In [None]:
class dense_auto(nn.Module):
    """Split DenseNet-121 with a convolutional auto-encoder after f_model.

    Two paths share f_model:
      * original path:   f_model -> r_model  -> ReLU/pool -> cls_model
      * intervened path: f_model -> auto-encoder (one bottleneck channel
        zeroed) -> r_freeze -> ReLU/pool -> cls_freeze

    The sub-modules are the module-level objects built in earlier cells
    (f_model, ae1, hd1, hd2, ae2, r_model, r_freeze, cls_model, cls_freeze),
    not copies, so every dense_auto instance shares weights with the others.

    Args:
        zero_channel: index of the hd1 output channel that is zeroed as the
            intervention. Defaults to 5, the previously hard-coded value.
    """

    def __init__(self, zero_channel=5):
        super(dense_auto, self).__init__()
        self.f_model = f_model
        self.ae1 = ae1
        self.hd1 = hd1
        self.hd2 = hd2
        self.ae2 = ae2
        self.r_model = r_model
        self.r_freeze = r_freeze
        self.cls_model = cls_model
        self.cls_freeze = cls_freeze
        self.zero_channel = zero_channel
        # Kept for backward compatibility (attribute access and global NumPy
        # RNG consumption); forward() does not use them.
        self.zero_out = np.random.randint(0, 16)
        self.prob_zero = np.random.uniform(0, 1, 1)

    def train(self, mode=True):
        """Set train/eval mode but always keep the frozen branch in eval mode.

        requires_grad=False does not stop BatchNorm layers from updating their
        running statistics in training mode, so r_freeze is pinned to eval().
        """
        super(dense_auto, self).train(mode)
        self.r_freeze.eval()
        return self

    def forward_interpret(self, x):
        """Auto-encoder bottleneck: hd1, zero channel ``zero_channel``, hd2."""
        x = self.hd1(x)
        x[:, self.zero_channel, :, :] = 0
        return self.hd2(x)

    def forward_shallow(self, x):
        """Run the auto-encoder on f_model features.

        Returns:
            (reconstruction, ae1 output copy, forward_interpret output copy)
        """
        x = self.ae1(x)
        y1 = x.clone()
        x = self.forward_interpret(x)
        y2 = x.clone()
        x = self.ae2(x)
        return x, y1, y2

    @staticmethod
    def _classify(features_net, classifier, x):
        """DenseNet tail: features -> ReLU -> global average pool -> classifier."""
        features = features_net(x)
        out = F.relu(features, inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1)
        return classifier(out)

    def forward(self, x):
        """Run both paths.

        Returns:
            x:  cls_model output on the original f_model features.
            x2: copy of the f_model features (auto-encoder input).
            y:  cls_freeze output on the reconstructed features.
            y1: copy of the ae1 output.
            y2: copy of the forward_interpret (hd2) output.
            y3: copy of the ae2 reconstruction.
        """
        x = self.f_model(x)
        x2 = x.clone()
        y, y1, y2 = self.forward_shallow(x)
        y3 = y.clone()
        x = self._classify(self.r_model, self.cls_model, x)
        y = self._classify(self.r_freeze, self.cls_freeze, y)
        return x, x2, y, y1, y2, y3


In [None]:
def tvloss(yhat, y):
    """Vertical-gradient matching loss between two 4D feature maps.

    Takes absolute differences between vertically adjacent rows (dim 2) of
    ``y`` and of ``yhat`` and returns the L1 distance between them divided
    by the height of ``y``. Only the vertical direction is compared.

    Args:
        yhat: 4D tensor (batch, channels, height, width).
        y: 4D tensor with the same shape as ``yhat``.

    Returns:
        0-dim tensor: sum(| |dy| - |dyhat| |) / height.

    Raises:
        ValueError: if ``y`` is not 4D, or the shapes differ (the
            subtraction would otherwise broadcast silently).
    """
    _, _, height, _ = y.size()
    if yhat.shape != y.shape:
        raise ValueError(
            "tvloss: shape mismatch {} vs {}".format(tuple(yhat.shape), tuple(y.shape)))
    dy = torch.abs(y[:, :, 1:, :] - y[:, :, :-1, :])
    dyhat = torch.abs(yhat[:, :, 1:, :] - yhat[:, :, :-1, :])
    return torch.norm(dy - dyhat, 1) / height

In [None]:
from tqdm import tqdm
import copy

save_path = "causal_dense_zeroout_last2conv.pth"


def train_causal_model(c_model, optimizer, num_epochs=10):
    """Train the causal auto-encoder and checkpoint the best epoch.

    Each batch minimises
        5 * tvloss(y1, y2)   (ae1 output vs. hd2 output gradient matching)
      + 2 * L1(y3, x2)       (auto-encoder reconstruction of f_model features)
      + 3 * KL(x || y)       (original-path vs. intervened-path class distribution)

    Whenever the epoch's mean batch loss improves, the model and optimizer
    state are saved to ``save_path``. The best weights are loaded back into
    ``c_model`` before it is returned.

    Args:
        c_model: a dense_auto instance.
        optimizer: optimizer over c_model's trainable parameters.
        num_epochs: number of passes over ``dataloders``.

    Returns:
        c_model with the best epoch's weights loaded.
    """
    # deepcopy: state_dict() returns live tensor references that later
    # optimizer steps would overwrite.
    best_model_wts = copy.deepcopy(c_model.state_dict())
    # Tracked across epochs (it used to be reset at the start of every epoch).
    best_loss = float('inf')
    reconstruction_criterion = nn.L1Loss()
    # 'batchmean' matches the mathematical KL definition; the default 'mean'
    # additionally divides by the number of classes.
    kl_criterion = nn.KLDivLoss(reduction='batchmean')

    for epoch in range(num_epochs):
        c_model.train(True)  # Set model to training mode
        running_loss = 0.0

        for inputs, _labels in tqdm(dataloders):
            if use_gpu:
                inputs = inputs.cuda()

            optimizer.zero_grad()

            x, x2, y, y1, y2, y3 = c_model(inputs)
            loss_interpret = tvloss(y1, y2)
            # The reconstruction y3 is the prediction and the f_model features
            # x2 the detached target, so the decoder receives gradients.
            loss_ae = reconstruction_criterion(y3, x2.detach())
            # x and y are already Softmax outputs (cls_model / cls_freeze end in
            # nn.Softmax), so take log() directly instead of a second softmax;
            # the clamp avoids log(0).
            loss_kl = kl_criterion(torch.log(y.clamp(min=1e-8)), x.detach())

            # The third term previously repeated loss_interpret and left
            # loss_kl unused.
            loss = 5 * loss_interpret + 2 * loss_ae + 3 * loss_kl
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # running_loss sums one value per batch, so average over batches
        # (dividing by the number of images under-reported it ~batch_size-fold).
        epoch_loss = running_loss / max(len(dataloders), 1)
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('Loss: {:.10f} '.format(epoch_loss))

        # deep copy the model
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            best_model_wts = copy.deepcopy(c_model.state_dict())
            state = {'model': c_model.state_dict(), 'optim': optimizer.state_dict()}
            torch.save(state, save_path)

    # load best model weights
    c_model.load_state_dict(best_model_wts)
    return c_model

In [None]:
# Build the causal model, moving it to the GPU only when one is available.
causal_model = dense_auto()
if use_gpu:
    causal_model = causal_model.cuda()
# Only parameters that are still trainable (requires_grad=True) are optimised.
optimizer_c = optim.Adam(
    [p for p in causal_model.parameters() if p.requires_grad], lr=0.001)
# List the trainable parameters for inspection.
for name, param in causal_model.named_parameters():
    if param.requires_grad:
        print(name)

In [None]:
from PIL import ImageFile
import warnings

# Let PIL decode truncated image files instead of raising while the
# DataLoader iterates, and keep the training log free of warnings.
ImageFile.LOAD_TRUNCATED_IMAGES = True
warnings.filterwarnings('ignore')

# One training pass; returns the model with the best epoch's weights loaded.
train_causal_model(causal_model, optimizer_c, num_epochs=1)

In [None]:
# Re-create the model and restore a saved checkpoint for evaluation.
# NOTE: dense_auto() reuses the module-level sub-modules (f_model, ae1, ...),
# so loading weights here also changes causal_model, which shares them.
# NOTE(review): this reads "causal_dense_zeroout.pth", not the save_path
# ("causal_dense_zeroout_last2conv.pth") written by training — confirm intended.
causal_eval = dense_auto()
if use_gpu:
    causal_eval = causal_eval.cuda()
checkpoint = torch.load("causal_dense_zeroout.pth",
                        map_location=None if use_gpu else "cpu")
causal_eval.load_state_dict(checkpoint["model"])

In [None]:
from PIL import Image
def computingECE(causal_model, image):
    """ECE (presumably "expected causal effect") of the intervention on one image.

    The image is preprocessed with the ImageNet transform and passed through
    ``causal_model`` in eval mode without gradients. The result is
    dot(causal_score - score, score), where ``score`` is the original-path
    class distribution (outputs[0]) and ``causal_score`` the intervened-path
    one (outputs[2]). The (max, argmax) of each is printed.

    Args:
        causal_model: a dense_auto instance.
        image: path to an image file readable by PIL.

    Returns:
        0-dim tensor with the ECE.
    """
    # Context manager closes the file; convert() returns an independent copy.
    with Image.open(image) as pil_img:
        img = pil_img.convert('RGB')
    trans = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    # Batch of one, on whichever device the model's parameters live on.
    device = next(causal_model.parameters()).device
    batch = trans(img).unsqueeze(0).to(device)

    # Eval mode so BatchNorm uses (and does not update) its running stats;
    # the caller's train/eval mode is restored afterwards.
    was_training = causal_model.training
    causal_model.eval()
    try:
        with torch.no_grad():
            outputs = causal_model(batch)
    finally:
        causal_model.train(was_training)

    score = outputs[0][0]
    print(torch.max(score, 0))
    causal_score = outputs[2][0]
    print(torch.max(causal_score, 0))
    causal_effect = causal_score - score
    return torch.dot(causal_effect, score)


image_path = "database/original_images/00000001_000.png"
ece = computingECE(causal_model, image_path)
ece

## OVERLAP: ECE on an image with occluded (white-filled) grid blocks

In [None]:
from PIL import Image
import cv2
import matplotlib.cm as cm
def computingECE(causal_model, image, blocks=(12, 22, 23, 32, 33)):
    """ECE of the intervention on an image with occluded grid cells.

    The image is read with OpenCV and divided into a 10x10 grid. The listed
    cells are painted white (255), and the result is passed through
    ``causal_model`` in eval mode without gradients. The ECE is
    dot(causal_score - score, score), where ``score`` is the original-path
    distribution (outputs[0]) and ``causal_score`` the intervened-path one
    (outputs[2]).

    Args:
        causal_model: a dense_auto instance.
        image: path to an image file readable by cv2.imread.
        blocks: grid cells to occlude, each encoded as ``10 * row + col`` with
            1-based row/col (e.g. 23 = row 2, column 3).
            NOTE(review): this decoding is an assumption. The original used the
            raw number as both row and column index, which for these values
            (>= 12) sliced past the image edge and masked nothing. Confirm.

    Returns:
        (ece, display): ece is a 0-dim tensor; display is the occluded image
        as an RGB uint8 array.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image``.
        ValueError: if a block decodes to a cell outside the 10x10 grid.
    """
    bgr = cv2.imread(image, 1)
    if bgr is None:
        raise FileNotFoundError("could not read image: {}".format(image))
    # OpenCV loads BGR; PIL.Image.fromarray and the ImageNet normalisation
    # below assume RGB channel order.
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    dx = img.shape[0] / 10
    dy = img.shape[1] / 10
    for block in blocks:
        row, col = divmod(block, 10)
        if not (1 <= row <= 10 and 1 <= col <= 10):
            raise ValueError("block {} is outside the 10x10 grid".format(block))
        # The slice covers every channel at once; no per-channel loop needed.
        img[int((row - 1) * dx):int(row * dx), int((col - 1) * dy):int(col * dy)] = 255
    masked_rgb = img

    trans = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    device = next(causal_model.parameters()).device
    batch = trans(Image.fromarray(masked_rgb)).unsqueeze(0).to(device)

    # Eval mode so BatchNorm uses (and does not update) its running stats;
    # the caller's train/eval mode is restored afterwards.
    was_training = causal_model.training
    causal_model.eval()
    try:
        with torch.no_grad():
            outputs = causal_model(batch)
    finally:
        causal_model.train(was_training)

    # outputs[0] / outputs[2] come from classifiers ending in nn.Softmax, so
    # they are used as probabilities directly (no second softmax).
    score = outputs[0].view(-1)
    causal_score = outputs[2].view(-1)
    causal_effect = causal_score - score
    ece = torch.dot(causal_effect, score)
    return ece, masked_rgb

In [None]:
# Occlusion example on one bounding-box image: compute the ECE and show the
# masked input that was fed to the model.
# NOTE(review): absolute, machine-specific path — consider a configurable DATA_DIR.
image_path = "/home/jovyan/EJ/causality/CheXNet-Keras/data/masked_data/bbox/00007830_013.png"
(ece, display) = computingECE(causal_model, image_path)
plt.imshow(display)
ece

## FOOLBOX: FGSM adversarial attack with foolbox on an ImageNet VGG-16

In [None]:
from PIL import Image
import cv2

# Re-load the image from the occlusion experiment with the ImageNet transform.
with Image.open(image_path) as pil_img:
    img_fool = pil_img.convert("RGB")
trans = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
img_fool = trans(img_fool)
# The tensor stays on the CPU: foolbox receives a NumPy array (converted below).
print(torch.max(img_fool))

import foolbox
from foolbox.models import PyTorchModel

# ImageNet-pretrained VGG-16 used as the attacked model.
vgg_fool = models.vgg16(pretrained=True).eval()
if use_gpu:
    vgg_fool = vgg_fool.cuda()

# bounds must enclose the normalised pixel range: ImageNet normalisation maps
# [0, 1] to roughly [-2.12, 2.64], so (-3, 3) is a safe envelope.
fool_model = PyTorchModel(vgg_fool, bounds=(-3, 3), num_classes=1000)

img_fool = img_fool.cpu().numpy()
print(img_fool.shape)
# Predict once and reuse the clean-image label as the attack's target label.
label = np.argmax(fool_model.predictions(img_fool))
print(label)

# apply attack on source image
attack = foolbox.attacks.FGSM(fool_model)
adversarial = attack(img_fool, label)
if adversarial is None:
    # NOTE(review): foolbox 1.x/2.x returns None when the attack finds no
    # adversarial example; confirm for the installed version.
    print('attack failed: no adversarial example found')
else:
    print('adversarial class', np.argmax(fool_model.predictions(adversarial)))