In [1]:
# case 6 of dongwhee
import os
import numpy as np
import random
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import torch.utils.data as data
import torchvision
from torchvision import transforms
from torchvision import models
import torchvision.models as models
import matplotlib.pyplot as plt
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
os.environ["TF_CPP_MIN_LOG_LEVEL"] = '3'

In [2]:
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0) # if use multi-GPU
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(0)
random.seed(0)

pretrained_model = models.vgg16_bn(pretrained=True)
if torch.cuda.device_count() > 1:
    pretrained_model = nn.DataParallel(pretrained_model)
pretrained_model.cuda() 
#print(pretrained_model)

new_model=models.vgg16_bn(pretrained=True).cuda()
if torch.cuda.device_count() > 1:
    new_model = nn.DataParallel(new_model)


In [3]:

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
transforms_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

transforms_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224), 
    transforms.ToTensor(),
    normalize,
])

In [4]:
file = open('Filter_for_Filter_gpu1_result.txt', 'w')    
num_epochs = 120
batchsize = 64
lr = 0.001
class_num=1000 
channel_per_packet=2 
packet_loss_per_feature=64
before_accuracy=0.0
before_lr=lr

In [5]:
TRAIN_DATA_PATH = "/media/2/Network/Imagenet_dup/train"
TEST_DATA_PATH="/media/2/Network/Imagenet_dup/val"
SAVE_PATH="/media/0/Network/0722_dongwhee"

In [6]:
if os.path.isdir(SAVE_PATH) is False:
    os.mkdir(SAVE_PATH+"/only_error_input_F4F")
    os.mkdir(SAVE_PATH+"/F4F_weight")

In [7]:
trainset = torchvision.datasets.ImageFolder(root=TRAIN_DATA_PATH, transform=transforms_train) 
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batchsize, shuffle=False,num_workers=4)
testset = torchvision.datasets.ImageFolder(root=TEST_DATA_PATH, transform=transforms_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=batchsize, shuffle=False, num_workers=4)

In [8]:

loss_start_index=0
# hook 설정!

activation_input = {}
def preprocessing(name): 
    def hook(model, input): 
        input[0][:,loss_start_index:loss_start_index+128] = 0 
        activation_input[name] = input[0]
    return hook

class F4F_only_error_index(nn.Module): 
    def __init__(self):
        super(F4F_only_error_index, self).__init__()
        #self.fc1=nn.Linear(512, 3 * 3 * 512)
        self.fc1=nn.Linear(512+3*3*512, 3 * 3 * 512)
        self.f4f_optimizer = torch.optim.SGD([self.fc1.weight, self.fc1.bias],lr=lr,weight_decay=1e-4)
        
    def forward(self,x): 
        x=self.fc1(x) 
        output=torch.tanh(x)
        return output

activation1 = {}
def get_activation1(name):
    def hook(model, input, output):
        activation1[name] = output.detach()
    return hook
        
activation2 = {}
def get_activation2(name): 
    def hook(model, input, output):
        activation2[name] = output
    return hook

def error_index_make(loss_start_index):
    error_index=[]
    for index in range(512):
        if loss_start_index<=index and index < loss_start_index+128: 
            error_index.append(1)
        else:
            error_index.append(0)
    error_index=torch.Tensor(error_index)
    error_index=error_index.unsqueeze(0).repeat(512,1)
    #print(error_index.size())
    return error_index

In [9]:

F4F=F4F_only_error_index() 
F4F.cuda()
criterion = nn.MSELoss().cuda()
optimizer = torch.optim.SGD(F4F.parameters(),lr=lr,weight_decay=1e-4)
F4F.train()

#### 기존 모델에서 forward hook 통해서 conv5_1의 output 저장 ###
for name, module in pretrained_model.named_modules():
    if name=="features.34": 
        pretrained_model.features[34].register_forward_hook(get_activation1(name))
        break

##### 새로운 모델에서 pre hook ##### (error input 넣기)
for name, module in new_model.named_modules():
    if name=="features.34": 
        module.register_forward_pre_hook(preprocessing(name))
        break
    
#### 새로운 모델에서 conv5_1의 결과 뽑기 ####
for name, module in new_model.named_modules():
    if name=="features.34": # 
        new_model.features[34].register_forward_hook(get_activation2(name))
        break

#### 원래 parameter 저장 ####
original_parameter = torch.empty(512,512,3,3) 
for name, parameter in new_model.named_parameters():
    if name == 'features.34.weight': 
        original_parameter=new_model.features[34].weight.data
        original_bias     =new_model.features[34].bias.data
        break        
original_parameter = original_parameter.cuda()
original_bias      = original_bias.cuda()

In [None]:

##### epoch 시작 #####
loss_start_index=0
for epoch in range(num_epochs): 

    ####### train #######
    running_loss=0.0
    correct_top1 = 0
    total = 0
    for idx, (images, labels) in enumerate(tqdm(trainloader,desc=f'EPOCH {epoch} ')):
        error_index=error_index_make(loss_start_index)
        error_index=error_index.cuda()
        
        #offset = F4F(error_index)
        #offset=torch.reshape(offset,[512,3,3])  
        #test_filter = original_parameter+offset
        test_filter = torch.empty(512,512,3,3)
        for name, parameter in new_model.named_parameters():
            if name == 'features.34.weight':
                weight = torch.reshape(original_parameter,(512,512*3*3)).to(device)
                #print(weight.size(),error_index.size())
                data = torch.cat( (weight,error_index), 1)
                offset = torch.reshape(F4F(data),(512,512,3,3))
                new_model.features[34].weight.data = original_parameter[:, :]+offset
                test_filter = new_model.features[34].weight
                #print(new_model.features[34].weight)
                break
        
        images = images.cuda()
        labels = labels.cuda()
        with torch.no_grad():
            out1=pretrained_model(images)
        out2=new_model(images)
        test_output = torch.nn.functional.conv2d(activation_input['features.34'].detach(), test_filter, original_bias, 1, 1)

        optimizer.zero_grad()
        loss = criterion(test_output,activation1['features.34']) 
        running_loss+=loss.item()
        loss.backward() 
        optimizer.step()
        _, predicted=torch.max(out2,1)
        total += labels.size(0) 
        correct_top1 += (predicted == labels).sum().item()
        #if idx>1700: 
        #    break
    print("train accuracy : {0:0.2f}%\n".format(correct_top1/total*100))
    file.write("===== {0}th Epoch ======\n".format(epoch+1))
    file.write("loss : {0}\n".format(running_loss))
    file.write("train accuracy : {0:0.2f}%\n".format(correct_top1/total*100))  
        
    #### val #####
    for name, parameter in new_model.named_parameters():
        if name == 'features.34.weight':
            weight = torch.reshape(original_parameter,(512,512*3*3)).to(device)
                #print(weight.size(),error_index.size())
            data = torch.cat( (weight,error_index), 1)
            offset = torch.reshape(F4F(data),(512,512,3,3))
            new_model.features[34].weight.data = original_parameter[:, :]+offset
            test_filter = new_model.features[34].weight
            #print(new_model.features[34].weight)
            break
    new_model.eval()

    correct_top1 = 0
    total = 0
    with torch.no_grad(): 
        for idx, (images, labels) in enumerate(testloader):
            images = images.cuda()
            labels = labels.cuda()
            outputs = new_model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0) 
            correct_top1 += (predicted == labels).sum().item()
    print("val top-1 accuracy : {0:0.2f}%\n".format(correct_top1 / total * 100))
    file.write("error channel {0}~{1}, epoch : [{2}/{3}]\n".format(loss_start_index,loss_start_index+128-1, epoch+1, num_epochs))        
    file.write("val top-1 accuracy :  {0:0.2f}%\n".format(correct_top1 / total * 100))
    # scheduler 부분
    if (correct_top1 / total * 100) < before_accuracy:
        optimizer = torch.optim.SGD(F4F.parameters(),lr=before_lr*0.5,weight_decay=1e-4)
        before_lr=before_lr*0.8
    before_accuracy=(correct_top1 / total * 100)
    
            
file.close() 

EPOCH 0 : 100%|██████████| 20019/20019 [2:30:54<00:00,  2.21it/s]  

train accuracy : 0.14%




EPOCH 1 :   0%|          | 0/20019 [00:00<?, ?it/s]

val top-1 accuracy : 0.68%



EPOCH 1 : 100%|██████████| 20019/20019 [2:29:10<00:00,  2.24it/s]  

train accuracy : 0.83%




EPOCH 2 :   0%|          | 0/20019 [00:00<?, ?it/s]

val top-1 accuracy : 0.68%



EPOCH 2 : 100%|██████████| 20019/20019 [2:28:56<00:00,  2.24it/s]  

train accuracy : 0.82%




EPOCH 3 :   0%|          | 0/20019 [00:00<?, ?it/s]

val top-1 accuracy : 0.68%



EPOCH 3 : 100%|██████████| 20019/20019 [2:29:00<00:00,  2.24it/s]  

train accuracy : 0.85%




EPOCH 4 :   0%|          | 0/20019 [00:00<?, ?it/s]

val top-1 accuracy : 0.68%



EPOCH 4 : 100%|██████████| 20019/20019 [2:29:16<00:00,  2.24it/s]  

train accuracy : 0.84%




EPOCH 5 :   0%|          | 0/20019 [00:00<?, ?it/s]

val top-1 accuracy : 0.68%



EPOCH 5 : 100%|██████████| 20019/20019 [2:29:23<00:00,  2.23it/s]  

train accuracy : 0.84%




EPOCH 6 :   0%|          | 0/20019 [00:00<?, ?it/s]

val top-1 accuracy : 0.68%



EPOCH 6 : 100%|██████████| 20019/20019 [2:29:28<00:00,  2.23it/s]  

train accuracy : 0.83%




EPOCH 7 :   0%|          | 0/20019 [00:00<?, ?it/s]

val top-1 accuracy : 0.68%



EPOCH 7 : 100%|██████████| 20019/20019 [2:29:10<00:00,  2.24it/s]  

train accuracy : 0.84%




EPOCH 8 :   0%|          | 0/20019 [00:00<?, ?it/s]

val top-1 accuracy : 0.68%



EPOCH 8 :  47%|████▋     | 9317/20019 [1:09:14<1:19:35,  2.24it/s]

In [None]:
!ls media/0/Network/0722_dongwhee//only_error_input_F4F/

In [None]:
!ls media/0/Network/0722_dongwhee/only_error_input_F4F/

In [None]:
os.mkdir("media/0/Network/0722_dongwhee/only_error_input_F4F")

In [None]:
'''
POCH 0 : 100%|██████████| 20019/20019 [2:54:46<00:00,  1.91it/s]  
train accuracy : 7.38%


EPOCH 1 :   0%|          | 0/20019 [00:00<?, ?it/s]
val top-1 accuracy : 66.51%

EPOCH 1 : 100%|██████████| 20019/20019 [2:56:36<00:00,  1.89it/s]  
train accuracy : 63.75%


EPOCH 2 :   0%|          | 0/20019 [00:00<?, ?it/s]
val top-1 accuracy : 66.45%

EPOCH 2 : 100%|██████████| 20019/20019 [2:56:45<00:00,  1.89it/s]  
train accuracy : 64.26%


EPOCH 3 :   0%|          | 0/20019 [00:00<?, ?it/s]
val top-1 accuracy : 66.42%

EPOCH 3 : 100%|██████████| 20019/20019 [2:55:56<00:00,  1.90it/s]  
train accuracy : 64.20%


EPOCH 4 :   0%|          | 0/20019 [00:00<?, ?it/s]
val top-1 accuracy : 66.48%

EPOCH 4 : 100%|██████████| 20019/20019 [2:57:15<00:00,  1.88it/s]  
train accuracy : 64.30%


EPOCH 5 :   0%|          | 0/20019 [00:00<?, ?it/s]
val top-1 accuracy : 66.45%

EPOCH 5 : 100%|██████████| 20019/20019 [2:56:55<00:00,  1.89it/s]  
train accuracy : 64.24%


EPOCH 6 :   0%|          | 0/20019 [00:00<?, ?it/s]
val top-1 accuracy : 66.47%

EPOCH 6 : 100%|██████████| 20019/20019 [2:55:25<00:00,  1.90it/s]  
train accuracy : 64.27%


EPOCH 7 :   0%|          | 0/20019 [00:00<?, ?it/s]
val top-1 accuracy : 66.47%
'''