In [1]:
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision import transforms
from torch.autograd import Variable

import os
import pandas as pd
import argparse
import nibabel as nib
import skimage.transform as skTrans
import numpy as np

from IPython import display

#load data
import dataloader
from dataloader import PerturbedDataloader

comp = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

pdl = PerturbedDataloader('./poisonedDataset', transform=comp, csv_file='./labels.csv')

image, label = pdl[1]
image2, label2 = pdl[628]
print(f"jpg: {image.shape}, {label}")
print(f"pt: {image2.shape}, {label2}")

jpg: torch.Size([3, 224, 224]), 0.0
pt: torch.Size([3, 224, 224]), 1.0


  image = torch.load(file_path).squeeze(0)


In [2]:
import classifier
from classifier import DorPatchClassifier

def init_weights(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1 or classname.find('BatchNorm') != -1:
        m.weight.data.normal_(0.00, 0.02)
        
#set up classifier
batch_size = 32
classifier = DorPatchClassifier()
classifier.apply(init_weights)

#load model if possible
loadModel = False
model_path = './DorPatchClassifier.pth'
if os.path.exists(model_path) and loadModel:
    classifier.load_state_dict(torch.load(model_path))

if torch.cuda.is_available():
    classifier = classifier.cuda()
    
optimizer = Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999))

In [3]:
#training function
loss = nn.BCELoss()

def train_classifier(optimizer, results, real_data, labels):
    optimizer.zero_grad()
    labels = labels.unsqueeze(1)
    error = loss(results, labels)
    error.backward()
    optimizer.step()
    
    prediction = classifier(real_data)
    return error, prediction

In [4]:
#training loop
# 80% for training 20% for testing
train_size = int(0.8 * len(pdl))  
test_size = len(pdl) - train_size  
train_dataset, test_dataset = random_split(pdl, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)

num_batches = len(train_loader)
num_epochs = 30

for epoch in range(num_epochs):
    for n_batch, (batch, label) in enumerate(train_loader):
        real_data = Variable(batch)
        result = classifier(real_data)
        print("CLASSIFIER FORWARDED")
        err, prediction = train_classifier(optimizer, result, real_data, label)
        result = result.detach()
        #show progress
        print(f"iter: {n_batch}/{num_batches} of epoch {epoch}/{num_epochs}")
        print(f"err: {err:.4f}")
        print(f"pred, gt:")
        for i in range(len(prediction)):
            print(f"\t{prediction[i].item():.4f}, {label[i]%10}")
        display.clear_output(True)
torch.save(classifier.state_dict(), model_path)

CLASSIFIER FORWARDED
iter: 15/16 of epoch 29/30
err: 0.0000
pred, gt:
	0.0000, 0.0
	1.0000, 1.0
	1.0000, 1.0
	1.0000, 1.0
	1.0000, 1.0
	0.0000, 0.0
	0.0000, 0.0
	1.0000, 1.0
	0.0000, 0.0
	0.0000, 0.0
	1.0000, 1.0
	0.0000, 0.0
	1.0000, 1.0
	1.0000, 1.0
	1.0000, 1.0
	1.0000, 1.0
	1.0000, 1.0
	1.0000, 1.0
	1.0000, 1.0
	1.0000, 1.0
	1.0000, 1.0
	1.0000, 1.0
	1.0000, 1.0


In [8]:
import torch
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix

total_loss = 0.0

mispredicts = 0

for n_batch, (batch, label) in enumerate(test_loader):
    real_data = Variable(batch)
    result = classifier(real_data)
    
    # Calculate error and prediction using your train_classifier
    err, prediction = train_classifier(optimizer, result, real_data, label)
    
    result = result.detach()
    if prediction - result > 0.1:
        mispredicts+=1
    # Update total loss
    total_loss += err

    # Show progress
print(f"avg error: {total_loss/len(test_loader)}")
print(f"mispredictions: {mispredicts} / {len(test_loader)}")

avg error: 0.36457696557044983
mispredictions: 1 / 126
