In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
os.environ["TF_CPP_MIN_LOG_LEVEL"] = '3'
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as dataset
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from copy import deepcopy as d_copy
import random

print(torch.__version__, torchvision.__version__)
print(torch.cuda.get_device_name(0))

1.8.0 0.2.2
TITAN RTX


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#error_index = 0
vgg16_bn = torchvision.models.vgg16_bn(pretrained=True)#.to(device)
print(device)

cuda


In [3]:
In_layer_number = 34 # 34 conv5_1 convolution
Out_layer_number = 36 # 36 conv5_1 relu 
error_index=0
max_epochs = 30
num_error = 128

In [4]:
# randomness 제어 
# https://hoya012.github.io/blog/reproducible_pytorch/
def set_randomness(seed=0):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
# func

# only apply for feature part (not pooling, classfier)
# because of layers.feature 
def split_layer(model,start,end):
    ct = 0
    split_model=[] # from start to Conv5_1(include ReLU)
    for name,layers in model.named_modules():
        #print(name,layer)
        #print(layers.features)
        for idx,layer in enumerate(layers.features):
            #print(idx,layer)
            if start <=idx and idx <=end :
                split_model.append(layer)
        break
    return nn.Sequential(*split_model)

def error_injection(name,num_error,start_index):
    def hook(model,input):
        start = start_index
        end = start_index + num_error
        input[0][:, start:end]=0
        print("shape :",input[0][:, start:end].size())
    return hook


In [5]:
!ls /media/2/hwbae0326/F4F/0708

acc_log_to34.txt  F4F_pytorch-to36_error_idx_mixed.ipynb


In [6]:
# dataset load
batch_size = 16 # 32~ out of memory in 3080
num_train = 128000
#num_train = 128
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

dataset_path = "/media/2/Network/Imagenet_dup/"
retrain_model_path = "/media/0/Network/0708_to_34models/"
# imagenet data load
train_dataset = dataset.ImageFolder(root=dataset_path+"train",
                                       transform=transform)
subset_train_dataset,_ = torch.utils.data.random_split(train_dataset, [num_train,len(train_dataset)-num_train])

test_dataset = dataset.ImageFolder(root=dataset_path+"val",
                                       transform=transform)
'''
train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                        batch_size=64,
                                        shuffle=False,
                                        num_workers=4)
'''

train_dataloader = torch.utils.data.DataLoader(subset_train_dataset,
                                        batch_size=batch_size,
                                        shuffle=True,
                                        num_workers=4) # for using subset
test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                        batch_size=batch_size,
                                        shuffle=True,
                                        num_workers=4)
print(len(train_dataloader),len(test_dataloader))

8000 3125


In [7]:
if os.path.isdir(retrain_model_path) is False:
    os.mkdir(retrain_model_path)
    

In [8]:
os.listdir(retrain_model_path+"../")

['make_F4F_pytorch-to34_F4Famended.ipynb',
 '.ipynb_checkpoints',
 'extracted_feature',
 'VGG16',
 'start.py',
 'acc_log_to34.txt',
 'make_F4F_pytorch-to34.ipynb',
 '0708_to_34models',
 '0624_to_34models']

In [9]:
seed = 0
set_randomness(seed)

In [10]:
# external variable in error_index, num_error

def make_error_info(error_index, num_error):
    data = []
    for i in range(511,-1,-1):
        if error_index <= i and i < error_index+num_error:
            data.append(1)
        else :
            data.append(0)
        #print(data)
    error_info = torch.Tensor(data)
    error_info  = error_info.unsqueeze(0).repeat(512,1)
    #print("error_info :",error_info)
    return error_info # 512,521
class F4F(nn.Module):
    def __init__(self):
        super().__init__()
        self.f4f = nn.Linear(3*3*512+512,3*3*512) # 4167,4608 filter which change feature.34 (Conv5_1)
        # 512 x5120 사이즈로 batch 저장
    def get_f4f_weight(self):
        # fc.weight.size(),fc.bias.size()
        return self.f4f.weight # torch.Size([4608, 5120])
    def forward(self,x):
        x = self.f4f(x)
        y = torch.tanh(x)
        return y
        

In [11]:
def hook_register(model,error_index,num_error):
    for name,layer in model.named_modules():
        #print(name,layer)
        if "34" in name  and isinstance(layer, torch.nn.modules.conv.Conv2d):
            print("input",name,layer) # target layer Conv5_1
            layer.register_forward_pre_hook(error_injection(name,num_error,error_index))
        break

In [12]:
class Target_model(nn.Module):
    def __init__(self,model):
        super().__init__()
        self.model = model
    def get_layer(self,idx):
        #print(self.model._modules['34'])
        layer =None
        try : # target model
            layer = self.model._modules[str(idx)]
        except KeyError: # test_model
            layer = self.model.features._modules[str(idx)]
        return layer
    def apply_f4f(self,f4f,error_info):
            #print(len(self.get_layer(34).weight.data))
            #print(self.get_layer(34).weight.data.size())
        weight = torch.reshape(self.get_layer(34).weight.data,(512,512*3*3)).to(device) # flatten [512,5210] (batch 512)
            #print(weight.size(),error_info.size())
        data = torch.cat( (weight,error_info), 1 )
            #print(data.size())
        offset = torch.reshape(f4f(data),(512,512,3,3))
        #print("offset is",offset[0][0][0])
        #offset = torch.tanh(offset)
        #print("before apply f4f",self.get_layer(34).weight.data[0][0][0])
        self.get_layer(34).weight.data = self.get_layer(34).weight.data + offset
        #print("after apply f4f",self.get_layer(34).weight.data[0][0][0])
    def forward(self,x,f4f,error_info):
        # apply_f4f는 매 epoch마다 동일하므로 
        self.apply_f4f(f4f,error_info)
        y = self.model(x)
        return y


In [13]:
# evaluation phasetraining
def eval(model,dataloader,epoch,f4f,error_info):

    model.cuda()
    model.eval()
    total = 0
    correct =0
    with torch.no_grad():
        print("======eval start=======")
        for i, data in enumerate(dataloader):
            inputs,labels = data
            inputs,labels = inputs.to(device), labels.to(device)
        
            y_hat = model(inputs,f4f,error_info)
            _, predicted = torch.max(y_hat.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            if(i%200 == 199):
                print("step : %d / %d acc : %.3f"
                      %(i + 1,int(len(dataloader)), correct*100/total))
                #print(".",end="")
        print("")
    acc = 100*correct/total
    print("%dth epoch acc of %s on imagenet : %.4f %%" %(epoch, model.__class__.__name__,acc)) 
    f = open(log_file,"a")
    print("%dth epoch acc of %s on imagenet : %.4f %%" %(epoch, model.__class__.__name__,acc),file=f) 
    f.close()
    print("======eval  end ======")  
    return acc
#torch.save(vgg16_bn.state_dict(), retrain_model_path+"test_vgg16_bn_state_dict.pt")
def model_copy(model):
    return d_copy(model.state_dict())

In [14]:
# training
def training(f4f,test_model,target_model,original_model,
             train_dataloader,test_dataloader,
             loss_fn,optimizer,
             error_idx,num_error,
             max_epochs=30,subset=False):
    
    target_model.to(device)
    original_model.to(device)
    error_info = make_error_info(error_index,num_error).to(device)
    first_feature = []
    original = []
    for epoch in range(max_epochs):
        running_loss = 0.0
        total_avg_loss = 0.0
        print("=====epoch %d start======"%(epoch+1))
        f4f.train()
        # update f4f filter
        #target_model.apply_f4f(f4f,error_info)
    
        # compare
        for i, data in enumerate(train_dataloader):
            print(".",end="")
            inputs,labels = data
            inputs,labels = inputs.to(device), labels.to(device)
            
            original_out = original_model(inputs)
            target_out = target_model(inputs,f4f,error_info)
            #print(original_out[0][0][0],target_out[0][0][0])
            #exit(0)
            if i == 0:
                first_feature.append(target_out[0])
                original.append(original_out[0])
                
            loss = loss_fn(original_out,target_out)
            #print(loss.size())
            running_loss += loss.item()
            target_model.model.zero_grad()
            loss.backward()
            optimizer.step()
            if i % 100 == 99: 
                total_avg_loss += running_loss
                print("")
                print('[%d, %5d] loss: %.6f' % (epoch+1, i+1, running_loss/100)) 
                running_loss = 0.0
        # save weight
        #print((len(train_dataloader)/batch_size))
        total_avg_loss /= int(len(train_dataloader)/batch_size)
        acc = eval(test_model,test_dataloader,epoch,f4f,error_info)
        
        torch.save(f4f.get_f4f_weight(), 
               retrain_model_path+"%s~%s_pkt_err_f4f_epoch_%s_acc_%.4f_loss_%.4f.pt"
               %(str(error_idx).zfill(3),str(error_idx+num_error).zfill(3),
                str(epoch+1).zfill(2),acc,total_avg_loss))    
    return original_out,first_feature
                

In [15]:
log_file = "./acc_log_to34.txt"

In [16]:
split_model = split_layer(vgg16_bn,0,Out_layer_number)

original_model = d_copy(split_model).to(device)
# subset of vgg16 (til 34 layer) with f4f
hook_register(split_model,error_index,num_error)
target_model = Target_model(split_model).to(device)
# whole vgg16 model with f4f
test = d_copy(vgg16_bn).to(device)
hook_register(test,error_index,num_error)
test_model = Target_model(test).to(device)


In [17]:
f4f = F4F().to(device)
optimizer = torch.optim.SGD(f4f.parameters(),lr=0.0005,weight_decay=1e-4)
loss_fn = torch.nn.MSELoss().to(device)

In [None]:

#optimizer = torch.optim.SGD(param_list,lr=0.01,weight_decay=1e-4)
first_feature = []
original_out = []
f = open(log_file,"w")
f.close()
num_error = 128
max_epoch = 30
for epoch in range(max_epoch):
    print("======= epoch %2d ======="%(epoch))
    split_model = split_layer(vgg16_bn,0,Out_layer_number)
    hook_register(split_model,error_index,num_error)
    target_model = Target_model(split_model).to(device)
    test = d_copy(vgg16_bn).to(device)
    hook_register(test,error_index,num_error)
    test_model = Target_model(test).to(device)
    
    for error_idx in range(0,512,num_error):
        tmp= training(f4f,test_model,target_model,original_model,
                  train_dataloader,test_dataloader,
                  loss_fn,optimizer,
                  error_idx,num_error,1,True)
    first_feature.append(tmp[1])
    original_out.append(tmp[0])

....................................................................................................
[1,   100] loss: 0.020667
....................................................................................................
[1,   200] loss: 0.020049
....................................................................................................
[1,   300] loss: 0.020195
....................................................................................................
[1,   400] loss: 0.020465
....................................................................................................
[1,   500] loss: 0.020689
....................................................................................................
[1,   600] loss: 0.020726
.........

In [None]:
# 여기서부터는 feature 그림 보기 위한 것들입니다.
len(first_feature),len(original_out)

In [None]:
len(first_feature[0])

In [None]:
print(original_out[0].size())
w = 10
h = 10
cols = 32
rows = 16
def feature_print(pic):
    print("test with 'after pooling 4 feature'")
    fig = plt.figure(figsize=(64,32))
    ax = []
    for i in range(cols*rows):
        ch = pic[i,:,:]
        ax.append(fig.add_subplot(rows,cols,i+1))
        ax[-1].set_title(str(i)+"th ch (14x14)")
        plt.imshow(ch)

In [None]:
# 기존 모델 (에러없이, f4f없이)을 통과한 결과
feature_print(original_out[0][0].cpu().detach().numpy())

In [None]:
# f4f을 통과한 결과  epoch 1
%matplotlib inline
feature_print(first_feature[0][0].cpu().detach().numpy())

In [None]:
# f4f을 통과한 결과  epoch 9
print("epoch 9")
feature_print(first_feature[0][9].cpu().detach().numpy())

In [None]:
# 14x14 의 feature 모두 합한 결과
tmp = first_feature[0][6][0]
for i in range(1,512):
    tmp += first_feature[0][6][i]
%matplotlib inline
plt.imshow(tmp.cpu().detach())


In [None]:
# 14x14 의 feature 모두 합한 결과
print("original")
tmp1 = original_out[0][6][0]
for i in range(1,512):
    tmp1 += original_out[0][6][i]
%matplotlib inline
plt.imshow(tmp1.cpu().detach())

In [None]:
# f4f을 통과한 결과  epoch 6
print("epoch 6")
feature_print(first_feature[0][6].cpu().detach().numpy())

In [None]:
running_loss = 0.0
error_info = make_error_info(error_index,num_error).to(device)
for feature,label in train_dataloader:
    print("======================================")
    feature,label = feature.to(device),label.to(device)
    print("dataloader data : ",feature.size(),label.size())
    target_out = test_model(feature,f4f,error_info)
    print("output :",target_out.size())
    original_out = original_model(feature)
    if torch.equal(target_out,original_out) is False :
        print("=====compare two output======")
        print(target_out[0][0][0][0])
        print(original_out[0][0][0][0])
    else :
        print("same")
    loss = loss_fn(original_out,target_out)
    running_loss += loss.item()
    target_model.model.zero_grad()
    loss.backward()
    optimizer.step()
    break
#target_model