# In [3]:
import torchvision
import  torch.nn as nn
import torch
import torch.nn.functional as F
from torchvision import transforms,models,datasets
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from torch import optim
import os
import shutil
from tqdm import tqdm_notebook as tqdm
from collections import OrderedDict
import random
# Seed Python's RNG so the random subsets drawn in this file are repeatable.
random.seed(0)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_root = '/home/hezhang/workspace/cvpr/data/training_set/'

# Build the imbalanced training set: a random third of the cats plus every dog.
# The derived directory is rebuilt from scratch so that re-running this step
# neither accumulates extra cats nor trips over copytree's refusal to write
# into an existing directory.
unbalance_root = data_root + 'unbalance_set/'
shutil.rmtree(unbalance_root, ignore_errors=True)

# os.listdir returns entries in arbitrary order; sorting makes the seeded
# sample identical across runs and machines.
cats_list = sorted(os.listdir(data_root + 'training_set/cats'))
interest_cats = random.sample(cats_list, int(len(cats_list) / 3))

# shutil.copy does not create missing parent directories.
os.makedirs(unbalance_root + 'cats/', exist_ok=True)
for item in interest_cats:
    raw_file = data_root + 'training_set/cats/' + item
    target_file = unbalance_root + 'cats/' + item
    shutil.copy(raw_file, target_file)

shutil.copytree(data_root + 'training_set/dogs/', unbalance_root + 'dogs/')
# Build a test subset with (up to) 500 randomly chosen images per class.
test_root = '/home/hezhang/workspace/cvpr/data/test_set/'
new_test_root = test_root + 'new_test_set/'

# Start from an empty target so repeated runs cannot accumulate more than the
# requested number of images per class.
shutil.rmtree(new_test_root, ignore_errors=True)

# Sorted listings keep the sample reproducible for a given random seed
# (os.listdir order is arbitrary).
test_cats_list = sorted(os.listdir(test_root + 'test_set/cats/'))
test_dogs_list = sorted(os.listdir(test_root + 'test_set/dogs/'))

# random.sample raises ValueError when asked for more items than exist, so
# cap the request at the number of available files.
new_test_cats = random.sample(test_cats_list, min(500, len(test_cats_list)))
new_test_dogs = random.sample(test_dogs_list, min(500, len(test_dogs_list)))

for class_name, chosen in (('cats', new_test_cats), ('dogs', new_test_dogs)):
    # shutil.copy does not create missing parent directories.
    os.makedirs(new_test_root + class_name + '/', exist_ok=True)
    for item in chosen:
        raw_file = test_root + 'test_set/' + class_name + '/' + item
        target_file = new_test_root + class_name + '/' + item
        shutil.copy(raw_file, target_file)

# Data pipeline at 224x224: Resize(255) -> CenterCrop(224) -> ToTensor().
# NOTE: `transform`, `dataset`, `train_loader` and `test_dataset` are assigned
# again further down in this file with a 64x64 pipeline.
train_data_dir = '/home/hezhang/workspace/cvpr/data/training_set/unbalance_set'
test_data_dir = '/home/hezhang/workspace/cvpr/data/test_set/new_test_set'

large_crop_steps = [
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
]
transform = transforms.Compose(large_crop_steps)

# Training images, served in shuffled batches of 400.
dataset = torchvision.datasets.ImageFolder(train_data_dir, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=400)

# Test images share the same preprocessing.
test_dataset = torchvision.datasets.ImageFolder(test_data_dir, transform=transform)
# Data pipeline at 64x64: Resize(65) -> CenterCrop(64) -> ToTensor().
batch_size = 40
train_data_dir = '/home/hezhang/workspace/cvpr/data/training_set/unbalance_set'
test_data_dir = '/home/hezhang/workspace/cvpr/data/test_set/new_test_set'

transform = transforms.Compose(
    [
        transforms.Resize(65),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
    ]
)

# Training split: shuffled mini-batches of `batch_size` images.
dataset = torchvision.datasets.ImageFolder(train_data_dir, transform=transform)
train_loader = torch.utils.data.DataLoader(
    dataset,
    shuffle=True,
    batch_size=batch_size,
)

# Test split: same preprocessing and batching (also shuffled).
test_dataset = torchvision.datasets.ImageFolder(test_data_dir, transform=transform)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    shuffle=True,
    batch_size=batch_size,
)
class ResidualBlock(nn.Module):
    """Residual block: two conv+batch-norm layers plus a shortcut connection.

    The main path is ``conv -> BN -> ReLU -> conv -> BN``. Its output is added
    to the shortcut and passed through a final ReLU. The shortcut is the
    identity when the block keeps both resolution and channel count, and a
    1x1 convolution followed by batch norm otherwise.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first convolution and of the projection
            shortcut; the second convolution always uses stride 1.
        kernel_size: Kernel size of the two main-path convolutions.
        padding: Padding of the two main-path convolutions.
        bias: Whether the two main-path convolutions have a bias term.
            Previously this argument was accepted but ignored; the default
            (``False``) keeps the earlier behaviour and parameter layout.
            The projection shortcut is always bias-free.
    """

    def __init__(self, in_channels, out_channels, stride=1, kernel_size=3, padding=1, bias=False):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size, 1, padding, bias=bias),
            nn.BatchNorm2d(out_channels),
        )

        if stride != 1 or in_channels != out_channels:
            # Shapes differ between input and main path: project the input.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            # An empty Sequential returns its input unchanged.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.conv2(self.conv1(x))
        out += self.shortcut(x)
        # Functional ReLU avoids constructing a new nn.ReLU module per call.
        return F.relu(out, inplace=True)
        
class ResNet34(nn.Module):
    """ResNet-style classifier assembled from ``ResidualBlock`` units.

    Layout: a 7x7 stride-2 stem convolution (``block1``), a max-pool followed
    by three 64-channel blocks (``block2``), then stages of 4, 6 and 3 blocks
    with 128, 256 and 512 channels (``block3``-``block5``). In each of those
    three stages the stride-2 block comes last. A 2x2 average pool and a
    linear layer produce the class scores.

    ``fc`` expects exactly 512 features, i.e. a 1x1 map after ``avgpool``;
    that is what a 64x64 input yields.

    Args:
        n_classes: Number of output scores produced by the final linear layer.
    """

    def __init__(self, n_classes):
        super().__init__()

        def make_stage(c_in, c_out, depth):
            # First block changes the channel count, the last one halves the
            # resolution; everything in between keeps the shape.
            units = [ResidualBlock(c_in, c_out, 1)]
            units.extend(ResidualBlock(c_out, c_out, 1) for _ in range(depth - 2))
            units.append(ResidualBlock(c_out, c_out, 2))
            return nn.Sequential(*units)

        stem_layers = [
            nn.Conv2d(3, 64, 7, 2, 3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        ]
        self.block1 = nn.Sequential(*stem_layers)
        self.block2 = nn.Sequential(
            nn.MaxPool2d(3, 2),
            *[ResidualBlock(64, 64, 1) for _ in range(3)]
        )
        self.block3 = make_stage(64, 128, 4)
        self.block4 = make_stage(128, 256, 6)
        self.block5 = make_stage(256, 512, 3)
        self.avgpool = nn.AvgPool2d(2)
        self.fc = nn.Linear(512, n_classes)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for a batch of images."""
        for stage in (self.block1, self.block2, self.block3, self.block4, self.block5):
            x = stage(x)
        pooled = self.avgpool(x)
        # Flatten everything except the batch dimension for the linear layer.
        features = pooled.view(pooled.size(0), -1)
        return self.fc(features)
data_root = '/home/hezhang/workspace/cvpr/data/training_set/'


def _copy_with_unique_names(names, src_dir, dst_dir):
    """Copy every entry of ``names`` (repeats allowed) from src_dir to dst_dir.

    A repeated name gets ``_<counter>`` inserted before its extension so that
    each draw ends up as a separate file instead of overwriting an earlier one.
    """
    counter = 0
    for name in names:
        target = os.path.join(dst_dir, name)
        stem, ext = os.path.splitext(name)
        while os.path.exists(target):
            target = os.path.join(dst_dir, '{}_{}{}'.format(stem, counter, ext))
            counter += 1
        shutil.copy(os.path.join(src_dir, name), target)


def generateTrainloader_v1(sampling='Under', batch_size=40):
    """Rebuild the resampled ``temp_set`` directory and return a loader for it.

    The source is ``<data_root>/unbalance_set`` and the result is written to
    ``<data_root>/temp_set``, whose ``cats`` and ``dogs`` folders are deleted
    and recreated on every call.

    Args:
        sampling: ``'Under'`` copies all cats and a random third of the dogs
            (drawn without replacement). ``'Over'`` copies all dogs and draws
            three times the number of cats with replacement.
        batch_size: Batch size of the returned loader.

    Returns:
        A shuffling ``DataLoader`` over the freshly built ``temp_set`` folder,
        with images resized to 65 and centre-cropped to 64x64.

    Raises:
        ValueError: If ``sampling`` is neither ``'Under'`` nor ``'Over'``.
    """
    if sampling not in ('Under', 'Over'):
        # Checked before touching the disk so a typo cannot wipe temp_set.
        raise ValueError(
            "sampling must be 'Under' or 'Over', got {!r}".format(sampling))

    source_root = os.path.join(data_root, 'unbalance_set')
    temp_root = os.path.join(data_root, 'temp_set')
    cats_src = os.path.join(source_root, 'cats')
    dogs_src = os.path.join(source_root, 'dogs')
    cats_dst = os.path.join(temp_root, 'cats')
    dogs_dst = os.path.join(temp_root, 'dogs')

    # Remove each class folder independently: a missing one must not prevent
    # the other from being cleared, otherwise copytree below would fail.
    for stale_dir in (dogs_dst, cats_dst):
        shutil.rmtree(stale_dir, ignore_errors=True)
    os.makedirs(temp_root, exist_ok=True)

    if sampling == 'Under':
        # Keep a random third of the dogs, all of the cats.
        os.makedirs(dogs_dst, exist_ok=True)
        dogs_list = sorted(os.listdir(dogs_src))
        interest_dogs = random.sample(dogs_list, int(len(dogs_list) / 3))
        for item in interest_dogs:
            shutil.copy(os.path.join(dogs_src, item), os.path.join(dogs_dst, item))
        shutil.copytree(cats_src, cats_dst)
    else:
        # Draw three times as many cats (with replacement), keep all dogs.
        os.makedirs(cats_dst, exist_ok=True)
        cats_list = sorted(os.listdir(cats_src))
        interest_cats = random.choices(cats_list, k=int(len(cats_list) * 3))
        _copy_with_unique_names(interest_cats, cats_src, cats_dst)
        shutil.copytree(dogs_src, dogs_dst)

    transform = transforms.Compose([transforms.Resize(65),
                                    transforms.CenterCrop(64),
                                    transforms.ToTensor()])

    dataset = torchvision.datasets.ImageFolder(temp_root, transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Label 0 is counted as "cat", every other label as "dog".
    num_cat = sum(1 for _, label in dataset.imgs if label == 0)
    num_dog = len(dataset.imgs) - num_cat
    print('generated dataset:\nthere are {} cats,\n{} dogs'.format(num_cat, num_dog))

    return train_loader

# Example: test_train_loader = generateTrainloader_v1(sampling='Over')

# In [4]:
def train(model, optimizer, criterian, flag, sampling='Under'):
    """Train ``model`` on a freshly resampled training set each epoch.

    Relies on module-level names: ``epochs`` (number of epochs), ``device``,
    ``test_loader`` (evaluated after every epoch) and
    ``generateTrainloader_v1`` (rebuilds the resampled training data).

    Args:
        model: Network to optimise; expected to already live on ``device``.
        optimizer: Optimiser bound to ``model``'s parameters.
        criterian: Loss callable invoked as ``criterian(outputs, labels)``.
        flag: Name of the checkpoint; weights go to ``output/<flag>.pkl``.
        sampling: Resampling mode forwarded to ``generateTrainloader_v1``.

    Returns:
        ``(losses, accs)``: per-epoch mean training loss and mean training
        accuracy (as fractions in [0, 1]).
    """
    losses = []
    accs = []
    # Create the checkpoint folder up front so a missing directory cannot
    # discard the result after the final epoch.
    os.makedirs('output', exist_ok=True)

    for epoch in range(epochs):
        train_loader = generateTrainloader_v1(sampling=sampling)
        running_loss = 0.0
        running_acc = 0.0
        model.train()
        for inputs, labels in tqdm(train_loader, total=len(train_loader)):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs.float())
            # Use the loss that was passed in, not a module-level global.
            loss = criterian(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_acc += (outputs.argmax(1) == labels).float().mean().item()

        test_acc = 0.0
        model.eval()
        with torch.no_grad():
            for inputs, labels in tqdm(test_loader, total=len(test_loader)):
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs.float())
                test_acc += (outputs.argmax(1) == labels).float().mean().item()

        epoch_loss = running_loss / len(train_loader)
        epoch_acc = running_acc / len(train_loader)
        epoch_test_acc = test_acc / len(test_loader)
        losses.append(epoch_loss)
        accs.append(epoch_acc)
        print('epochs {}/{} '.format(epoch+1,epochs))
        # Accuracies are fractions; scale by 100 to match the '%' in the label.
        print('traing_acc : {:.2f}%'.format(100 * epoch_acc))
        print('loss : {:.4f}'.format(epoch_loss))
        print('test_acc : {:.2f}%'.format(100 * epoch_test_acc))

    torch.save(model.state_dict(), 'output/'+ flag + '.pkl')
    return losses, accs
        
# Train a two-class ResNet34 with under-sampling of the majority class.
model = ResNet34(2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
# Alternative to resampling: weight the classes in the loss instead, e.g.
#   class_weights = torch.Tensor([1.5, 0.5]).to(device)
#   criterion = nn.CrossEntropyLoss(reduction='mean', weight=class_weights)
# `reduction` is the supported keyword; `reduce` is a deprecated boolean flag
# and only produced mean reduction here because the string 'mean' is truthy.
criterion = nn.CrossEntropyLoss(reduction='mean')
epochs = 20
flag = 'resnet_undersampling'
train(model, optimizer, criterion, flag, sampling='Under')

from sklearn.metrics import f1_score, recall_score, precision_score, confusion_matrix


def _precision_recall_f1(tp, fp, fn):
    """Return (precision, recall, F1) for one class; 0.0 where undefined."""
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    total = precision + recall
    # F1 is the harmonic mean: 2 * P * R / (P + R).
    f1 = 2 * precision * recall / total if total else 0.0
    return precision, recall, f1


# Evaluate a saved checkpoint on the test loader and report per-class metrics.
# NOTE(review): this loads 'resnet_base.pkl', not the 'resnet_undersampling'
# checkpoint written by the training call above -- confirm this is intended.
weight_path = 'output/resnet_base.pkl'
model = ResNet34(2).to(device)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load(weight_path, map_location=device))
model.eval()
test_acc = 0
TN,FP,FN,TP = 0, 0, 0, 0
with torch.no_grad():
    for inputs, labels in tqdm(test_loader, total=len(test_loader)):
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs.float())
        predictions = outputs.argmax(1)
        test_acc += (predictions == labels).float().mean().item()
        # confusion_matrix takes (y_true, y_pred). Passing labels=[0, 1] keeps
        # the matrix 2x2 even when a batch happens to contain a single class.
        tn, fp, fn, tp = confusion_matrix(
            labels.cpu().numpy(), predictions.cpu().numpy(), labels=[0, 1]
        ).ravel()
        TN += tn
        FP += fp
        FN += fn
        TP += tp
        print(test_acc)
        print(tn, fp, fn, tp)
        print('--------------------****************---------------')
print(TN, FP, FN, TP)
print('accuracy:\n')
print(test_acc/len(test_loader))

# Label 1 is the positive class of the confusion matrix (reported as dogs).
precision, recall, f1 = _precision_recall_f1(TP, FP, FN)
print('metrics of dogs:\n')
print('precision', precision, 'recall', recall, 'f1-score', f1)

# Same counts seen from label 0 (reported as cats): positives and negatives
# swap roles.
TN_1 = TP
FP_1 = FN
FN_1 = FP
TP_1 = TN
precision, recall, f1 = _precision_recall_f1(TP_1, FP_1, FN_1)
print('metrics of cats:\n')
print('precision', precision, 'recall', recall, 'f1-score', f1)