In [3]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
os.environ["TF_CPP_MIN_LOG_LEVEL"] = '3'
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as dataset
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from copy import deepcopy as d_copy
import random

# Report the library versions and select the device used by the rest of the notebook.
print("===INFO===")
print("torch ver : %s\ntorchvision ver : %s " %(torch.__version__, torchvision.__version__))
# Only query the GPU name when CUDA is usable; torch.cuda.get_device_name raises
# on a machine without CUDA, which would defeat the CPU fallback below.
if torch.cuda.is_available():
    print("GPU model :",torch.cuda.get_device_name(0))
else:
    print("GPU model : none (CUDA unavailable, using CPU)")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

===INFO===
torch ver : 1.8.0
torchvision ver : 0.2.2 
GPU model : TITAN RTX


In [4]:
# Randomness control: seed every RNG this notebook relies on.
# Reference: https://hoya012.github.io/blog/reproducible_pytorch/
def set_randomness(seed=0):
    """Seed the Python, NumPy and PyTorch generators and pin the cuDNN flags.

    Args:
        seed: Value passed to every seeding call (default 0).
    """
    # Host-side generators.
    random.seed(seed)
    np.random.seed(seed)
    # PyTorch generators: CPU, current CUDA device, then all CUDA devices.
    for seed_fn in (torch.manual_seed,
                    torch.cuda.manual_seed,
                    torch.cuda.manual_seed_all):
        seed_fn(seed)
    # Request deterministic cuDNN behaviour and switch benchmark mode off.
    cudnn_cfg = torch.backends.cudnn
    cudnn_cfg.deterministic = True
    cudnn_cfg.benchmark = False
# func

# Works on the feature part only (not the pooling / classifier parts),
# because the layers are read from the `.features` attribute.
def split_layer(model,start,end):
    """Gather the feature layers whose index lies in [start, end] into an nn.Sequential.

    Args:
        model: Object whose first `named_modules()` entry exposes an iterable `features`.
        start: First index to keep (inclusive).
        end: Last index to keep (inclusive).

    Returns:
        nn.Sequential wrapping the selected layer objects (the same instances, not copies).
    """
    # Only the first entry yielded by named_modules() is ever inspected.
    first_entry = next(iter(model.named_modules()), None)
    if first_entry is None:
        return nn.Sequential()
    _, root = first_entry
    kept = [layer for position, layer in enumerate(root.features)
            if start <= position <= end]
    return nn.Sequential(*kept)

def error_injection(name,num_error,start_index):
    """Create a hook that zeroes a slice of its first input along dimension 1.

    Args:
        name: Not used by the returned hook.
        num_error: Number of consecutive indices along dimension 1 to zero.
        start_index: First index along dimension 1 to zero.

    Returns:
        A callable `hook(model, input)` that overwrites
        `input[0][:, start_index:start_index + num_error]` with 0 in place
        and returns None.
    """
    stop_index = start_index + num_error

    def hook(model,input):
        # Only the first element of `input` is touched, and it is modified in place.
        first_input = input[0]
        first_input[:, start_index:stop_index] = 0

    return hook


In [5]:
#set_randomness(0)

In [3]:

def get_dataset(num_train,batch_size,
                dataset_path,retrain_model_path):
    """Build the training (random subset) and validation dataloaders.

    Args:
        num_train: Number of training images kept by the random split.
        batch_size: Batch size used for both dataloaders.
        dataset_path: Directory containing the `train` and `val` image folders.
        retrain_model_path: Output directory; created (with parents) if missing.

    Returns:
        (train_dataloader, test_dataloader)
    """
    if not os.path.isdir(retrain_model_path):
        # makedirs also creates missing parents, unlike os.mkdir.
        os.makedirs(retrain_model_path, exist_ok=True)
        # Resolve the parent explicitly: appending "../" to the raw string only
        # works when the path happens to end with a separator.
        parent_dir = os.path.dirname(os.path.abspath(retrain_model_path))
        print("retrain model path created :",os.listdir(parent_dir))

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Training pipeline: random crop to 224 and random horizontal flip.
    transforms_train = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    # Validation pipeline: resize to 256, then centre crop to 224.
    transforms_test = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    # os.path.join accepts dataset_path with or without a trailing separator.
    train_dataset = dataset.ImageFolder(root=os.path.join(dataset_path, "train"),
                                        transform=transforms_train)
    # Keep num_train randomly chosen samples; the remainder is discarded.
    subset_train_dataset,_ = torch.utils.data.random_split(
        train_dataset, [num_train,len(train_dataset)-num_train])
    test_dataset = dataset.ImageFolder(root=os.path.join(dataset_path, "val"),
                                       transform=transforms_test)

    train_dataloader = torch.utils.data.DataLoader(subset_train_dataset,
                                        batch_size=batch_size,
                                        shuffle=True,
                                        num_workers=4) # for using subset
    test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                        batch_size=batch_size,
                                        shuffle=True,
                                        num_workers=4)
    # len() of a DataLoader is the number of batches, not the number of images.
    print("train dataset[%d], test dataset[%d] are loaded"%(len(train_dataloader),len(test_dataloader)))
    return train_dataloader,test_dataloader

In [4]:
# external variable in error_index, num_error

def make_error_info(error_index, num_error):
    """Build a (512, 512) tensor in which every row flags the erroneous positions.

    A position i in 0..511 holds 1 when error_index <= i < error_index + num_error
    and 0 otherwise; the window is not wrapped around past index 511.

    Args:
        error_index: First flagged position.
        num_error: Number of consecutive flagged positions.

    Returns:
        torch.Tensor of shape (512, 512) with 512 identical rows.
    """
    window_end = error_index + num_error
    flags = [1 if error_index <= pos < window_end else 0
             for pos in range(512)]
    # One row of flags, replicated 512 times along dimension 0.
    return torch.Tensor(flags).unsqueeze(0).repeat(512,1) # shape (512, 512)
class F4F(nn.Module):
    """Fully connected network that maps flattened filters plus error info to new filter values.

    forward() sends `filter_set` through one linear branch and `error_info`
    through a two-layer branch, concatenates both along dimension 1 and maps
    the result through two more linear layers followed by tanh.
    """

    def __init__(self):
        super().__init__()
        filter_dim = 3*3*512  # 4608 values per flattened filter
        # Branch applied to the `filter_set` argument of forward().
        self.Ffc1 = nn.Linear(filter_dim,filter_dim,bias=True)

        # Branch applied to the `error_info` argument: 512 -> 512 -> 128.
        self.Efc1 = nn.Linear(512,512,bias=True)
        self.Efc2 = nn.Linear(512,128,bias=True)

        # Head over the concatenated branches: 4736 -> 4736 -> 4608.
        self.fc1 = nn.Linear(filter_dim+128,4736,bias=True)
        self.fc2 = nn.Linear(4736,filter_dim,bias=True)

        # Only fc2 is re-initialised with Xavier-normal weights; the other
        # layers keep whatever nn.Linear sets up on construction.
        nn.init.xavier_normal_(self.fc2.weight)

    def get_f4f_weight(self):
        # NOTE(review): this returns the bound method `self.parameters` itself,
        # not the parameters; callers must call the result to iterate them.
        # Confirm whether a state dict / parameter list was intended.
        return self.parameters

    def forward(self,x,filter_set,error_info):
        """Compute new filter values; `x` is accepted but not used here."""
        filter_branch = torch.relu(self.Ffc1(filter_set))

        error_branch = torch.relu(self.Efc1(error_info))
        error_branch = torch.relu(self.Efc2(error_branch))

        merged = torch.cat((filter_branch,error_branch),1)
        hidden = torch.relu(self.fc1(merged))
        # tanh keeps every output value inside (-1, 1).
        return torch.tanh(self.fc2(hidden))
"""
        data = torch.cat( (error_info,weight), 1 ) #210801 error_info를 앞에 붙이는 방법
        offset = torch.reshape(f4f(data),(512,512,3,3))
        #self.get_layer(34).weight.data = self.get_layer(34).weight.data + offset
        
        self.get_layer(34).weight.data = offset
"""

In [5]:
def hook_register(model, num_error=0, start_index=0):
    """Attach the error-injection pre-hook to the target Conv2d layer.

    Walks `model.named_modules()` and registers the hook on the first Conv2d
    whose qualified name contains "34", then stops.

    Args:
        model: Module to search.
        num_error: Number of input channels the hook zeroes (default 0, i.e.
            the hook is registered but leaves the input unchanged).
        start_index: First input channel the hook zeroes (default 0).
    """
    for name,layer in model.named_modules():
        if "34" in name and isinstance(layer, torch.nn.modules.conv.Conv2d):
            print("input",name,layer) # target layer Conv5_1
            # error_injection needs all three arguments; passing only `name`
            # raises TypeError.
            layer.register_forward_pre_hook(
                error_injection(name,num_error,start_index))
            # Stop only after the target was found. An unconditional break
            # would end the search on the first module (the root), so no hook
            # would ever be registered.
            break

In [10]:
class Target_model(nn.Module):
    """Wrapper that rewrites the weights of layer 34 with an F4F network before each forward pass."""

    def __init__(self,model):
        super().__init__()
        self.model = model

    def get_layer(self,idx):
        """Return sub-module `idx`, looked up on the model itself or, failing that, on `model.features`."""
        key = str(idx)
        try:
            # Wrapped model stores the layer directly under this key.
            return self.model._modules[key]
        except KeyError:
            # Otherwise the layer lives inside the `features` container.
            return self.model.features._modules[key]

    def apply_f4f(self,x,f4f,error_info):
        """Replace the weights of layer 34 with the output of `f4f`."""
        target_layer = self.get_layer(34)
        # Flatten to [512, 4608]: one row per filter.
        flat_filters = torch.reshape(target_layer.weight.data,(512,512*3*3)).to(device)
        new_filters = torch.reshape(f4f(x,flat_filters,error_info),(512,512,3,3))
        # NOTE(review): the weights are read and written through `.data`, which
        # is outside autograd tracking; confirm gradients reach `f4f` as intended.
        target_layer.weight.data = new_filters

    def forward(self,x,f4f,error_info):
        """Run the wrapped model on `x` after rewriting layer 34.

        Returns:
            (model output, layer-34 weights before the rewrite, weights after it)
        """
        # NOTE(review): "before" is whatever the layer currently holds, i.e. the
        # result of the previous call's rewrite rather than a fixed original.
        weight_before = self.get_layer(34).weight.data
        self.apply_f4f(x,f4f,error_info)
        weight_after = self.get_layer(34).weight.data
        return self.model(x), weight_before, weight_after


In [7]:
# Evaluation phase.
def eval(model,dataloader,epoch,batch_size,
         loss_fn,f4f,error_info,log_file):
    """Evaluate `model` on `dataloader`, print the result and append it to `log_file`.

    Note: this function shadows the built-in `eval` inside the notebook.

    Args:
        model: Callable as `model(inputs, f4f, error_info)` returning
            (logits, origin_weight, replace_weight).
        dataloader: Sized iterable of (inputs, labels) batches.
        epoch: Unused; kept for interface compatibility.
        batch_size: Used only to normalise the accumulated loss.
        loss_fn: Called as `loss_fn(0.5, logits, labels, origin_weight, replace_weight)`.
        f4f: Forwarded to `model`.
        error_info: Forwarded to `model`.
        log_file: Path of the text log to append to.

    Returns:
        Accuracy in percent (0.0 when the dataloader yields no samples).
    """
    # Same CUDA-or-CPU choice as the notebook-level `device`, so evaluation
    # also works on a machine without a GPU.
    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(run_device)
    # Leaves the model in eval mode after returning.
    model.eval()
    total = 0
    correct = 0
    total_loss = 0.0
    with torch.no_grad():
        print("======eval start=======")
        for i, (inputs, labels) in enumerate(dataloader):
            inputs, labels = inputs.to(run_device), labels.to(run_device)

            y_hat, origin_weight, replace_weight = model(inputs,f4f,error_info)

            _, predicted = torch.max(y_hat, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            loss = loss_fn(0.5,y_hat,labels,
                           origin_weight, replace_weight)
            total_loss += loss.item()

            # Progress report every 200 batches.
            if i % 200 == 199:
                print("step : %d / %d acc : %.3f"
                      %(i + 1,int(len(dataloader)), correct*100/total))
        print("")

    # Guard both divisions so an empty dataloader reports zeros instead of
    # raising ZeroDivisionError.
    acc = 100*correct/total if total else 0.0
    loss_denominator = len(dataloader)*batch_size
    avg_loss = total_loss / loss_denominator if loss_denominator else 0.0

    summary = "Eval acc of model on imagenet : %.4f %%, Loss : %.4f" %(acc,avg_loss)
    print(summary)
    # Context manager closes the log even if the write fails.
    with open(log_file,"a") as f:
        print(summary,file=f)
    print("======eval  end ======")
    return acc


In [8]:
# training
def training(f4f,target_model,original_model,
             train_dataloader,test_dataloader,batch_size,
             log_file,retrain_model_path,
             loss_fn,optimizer,
             num_error,
             max_epochs=30,subset=False):
    """Train `f4f` through `target_model` and evaluate after every epoch.

    Args:
        f4f: Network whose gradients are zeroed before each backward pass.
        target_model: Called as `target_model(inputs, f4f, error_info)` and
            expected to return (output, origin_weight, replace_weight).
        original_model: Moved to the device; not otherwise used here.
        train_dataloader: Iterable of (inputs, labels) training batches.
        test_dataloader: Dataloader passed to `eval` after each epoch.
        batch_size: Used to normalise the reported average loss.
        log_file: Path forwarded to `eval` for logging.
        retrain_model_path: Unused here (checkpoint saving is disabled).
        loss_fn: Called as `loss_fn(0.5, output, labels, origin_weight, replace_weight)`.
        optimizer: Optimizer stepped once per batch.
        num_error: Width of the error window passed to `make_error_info`.
        max_epochs: Number of epochs to run.
        subset: Unused; kept for interface compatibility.

    Returns:
        (first_feature, first_label, offset_info): outputs and labels of the
        first 100 batches, and one snapshot of layer 34 per epoch.

    Raises:
        ValueError: If an epoch processes no training samples.
    """
    first_feature = []
    first_label = []
    offset_info = []
    target_model.to(device)
    original_model.to(device)

    feature_num = 100  # number of leading batches whose outputs are kept
    for epoch in range(max_epochs):
        running_loss = 0.0
        total_loss = []
        total_avg_loss = 0.0
        total = 0
        correct = 0
        f4f.train()
        # NOTE(review): `eval` below puts target_model into eval mode and it is
        # never switched back, so the first epoch may run in a different mode
        # from later ones — confirm which mode is intended.

        for i, (inputs, labels) in enumerate(train_dataloader):
            # The error window start cycles through 0..511 with the batch index.
            error_index = i % 512
            error_info = make_error_info(error_index,num_error).to(device)

            if i % 10 == 0:
                print(".",end="")
            inputs,labels = inputs.to(device), labels.to(device)
            target_out,origin_weight, replace_weight = target_model(inputs,f4f,error_info)

            if len(first_feature) < feature_num:
                # Stored as-is (still attached to the autograd graph).
                first_feature.append(target_out)
                first_label.append(labels)
            _,predicted = torch.max(target_out,1)

            total += labels.size(0)
            correct += (predicted==labels).sum().item()
            loss = loss_fn(0.5,target_out,labels,
                           origin_weight, replace_weight)

            running_loss += loss.item()
            f4f.zero_grad()
            loss.backward()
            optimizer.step()

            # Report and reset the running loss every 100 batches.
            if i % 100 == 99:
                total_loss.append(running_loss/100)
                print("")
                print('[%d, %5d] loss: %.6f' % (epoch+1, i+1, running_loss/100))
                running_loss = 0.0

        # Fail with a clear message instead of a ZeroDivisionError / NameError
        # further down when the loader produced nothing.
        if total == 0:
            raise ValueError("train_dataloader produced no samples; nothing to train on")

        if len(total_loss) != 0:
            total_avg_loss = sum(total_loss)/(len(total_loss)*batch_size)
        acc = 100*correct/total
        if total_avg_loss != 0:
            print("total average loss : %.3f" %(total_avg_loss))
        else :
            print("total loss :" ,total_loss)
        print("== epoch %2d == train acc : %.4f" %(epoch,acc))
        # Evaluation reuses the error_info of the last training batch.
        acc = eval(target_model,test_dataloader,epoch,batch_size,
                   loss_fn,f4f,error_info,log_file)

        # Store a copy: appending the live layer would make every entry alias
        # the same object, so all epochs would show the final weights.
        offset_info.append(d_copy(target_model.get_layer(34)))
    return first_feature,first_label,offset_info
                