In [None]:
import os
# Restrict CUDA to device 0; this runs before torch is imported in the next cell.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [None]:
import h5py
import matplotlib.pyplot as plt
import numpy as np
from os.path import join as pj
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.models as tv_models
import visdom
%matplotlib inline

# Train Config

In [None]:
class args:
    """Hyper-parameters and paths for one insect2vec training run.

    Every non-dunder class attribute is written to ``args.txt`` by
    ``save_experiment_args`` (via ``generate_args_map``), in definition order.
    """
    # experiment_name
    experiment_name = "insect2vec_b20_epoch1000"
    # paths -- the project root may be overridden with the INSECT_PROJECT_ROOT
    # environment variable so the notebook is not tied to one user's home dir;
    # the default reproduces the original absolute paths exactly.
    project_root = os.environ.get("INSECT_PROJECT_ROOT", "/home/tanida/workspace/Insect_project")
    all_data_path = pj(project_root, "data/all_classification_data/classify_insect_std")
    model_save_path_root = pj(project_root, "model/classification/insect2vec", experiment_name)
    semantic_save_path_root = pj(project_root, "data/insect_semantic_vector", experiment_name)
    # train config
    vector_length = 100      # embedding size produced by img2vec
    bs = 20                  # training batch size
    num_workers = 20
    lr = 1e-5                # Adam learning rate
    alpha = 1e-2             # weight of the reconstruction (match) loss
    lamda = 1e-1             # weight of the parameter-norm loss; 0 disables it
    save_interval = 100      # checkpoint period, in epochs
    nepoch = 1000
    # visdom
    visdom = True            # plot loss curves to a visdom server
    port = 8098

# Logger

In [None]:
class Logger(object):
    """Append-only text logger writing to ``file_root/filename``.

    The directory is created on construction so the first ``write`` does not
    fail on a missing directory.
    """

    def __init__(self, file_root, filename):
        self.file_path = pj(file_root, filename)
        # exist_ok avoids the race of a separate exists() check; an empty root
        # means the current directory, which needs no creation (and
        # os.makedirs("") would raise).
        if file_root:
            os.makedirs(file_root, exist_ok=True)

    def write(self, msg):
        """Append ``msg`` verbatim (no newline added); no-op if ``file_path`` is None."""
        if self.file_path is not None:
            with open(self.file_path, "a", encoding="utf-8") as f:
                f.write(msg)

def generate_args_map(args):
    """Return the user-defined attributes of ``args`` as a dict.

    Entries of ``args.__dict__`` whose name starts with ``__`` (``__module__``,
    ``__doc__``, ...) are dropped; all others are kept in definition order.
    """
    return {
        name: value
        for name, value in args.__dict__.items()
        if not name.startswith("__")
    }

def save_experiment_args(args_logger, args, args_map=None):
    """Write a human-readable dump of the experiment configuration.

    Args:
        args_logger: object with a ``write(str)`` method, e.g. ``Logger``.
        args: config object; ``args.experiment_name`` is written as a header.
        args_map: optional precomputed ``{name: value}`` mapping. When omitted it
            is derived from ``args`` with ``generate_args_map`` instead of being
            read from a module-level global.
    """
    if args_map is None:
        args_map = generate_args_map(args)
    args_logger.write("\nTraining on: " + args.experiment_name + "\n")
    args_logger.write("Using the specified args:"+"\n")
    for k, v in args_map.items():
        args_logger.write(str(k) + ": " + str(v) + "\n")

# Visualize

In [None]:
def _new_loss_window(vis, title):
    """Create a visdom line window (epoch vs. loss) titled ``title``; return vis.line's handle.

    The window starts with a single (0, 0) point; later points are appended
    with ``update='append'``.
    """
    return vis.line(
        X=np.array([0]),
        Y=np.array([0]),
        opts=dict(
            title=title,
            xlabel='epoch',
            ylabel='loss',
            width=800,
            height=400
        )
    )

if args.visdom:
    # create visdom client and one window per loss curve
    vis = visdom.Visdom(port=args.port)
    win_match_loss = _new_loss_window(vis, 'match_loss')
    win_norm_loss = _new_loss_window(vis, 'norm_loss')
    win_all_loss = _new_loss_window(vis, 'train_loss')

In [None]:
def visualize(vis, phase, visualized_data, window):
    """Append the point (``phase``, ``visualized_data``) to visdom window ``window``."""
    point_x = np.array([phase])
    point_y = np.array([visualized_data])
    vis.line(X=point_x, Y=point_y, win=window, update='append')

# Dataset

In [None]:
class insect_dataset(data.Dataset):
    """In-memory dataset of (image, label) pairs read from one HDF5 file.

    The file must contain datasets ``"X"`` (images) and ``"Y"`` (labels); both
    are read fully into memory.

    Attributes:
        images: float tensor, ``X`` with axes 1 and -1 swapped so the channel
            axis lands in dim 1. NOTE(review): if ``X`` is (N, H, W, C) this
            yields (N, C, W, H), i.e. the spatial axes are transposed too;
            harmless for square images but confirm it is intended.
        labels: tensor built from ``Y``.
        class_num: number of distinct label values.
        class_count: number of samples per distinct label, in ascending label
            order (as returned by ``np.unique``). Downstream code indexes by the
            raw label value, so labels are presumably ``0..class_num-1`` -- verify.
    """

    def __init__(self, data_path):
        images, labels, class_num, class_count = self.load_data(data_path)
        self.images = torch.from_numpy(images).transpose(1, -1).float()
        self.labels = torch.from_numpy(labels)
        self.class_num = class_num
        self.class_count = class_count

    def __getitem__(self, index):
        image, label = self.images[index], self.labels[index]
        return image, label

    def __len__(self):
        return self.images.shape[0]

    def load_data(self, data_path):
        """Read ``X`` and ``Y`` from ``data_path``; return ``(X, Y, class_num, class_count)``."""
        # Open read-only explicitly: older h5py defaults to mode 'a', which can
        # create or modify the file.
        with h5py.File(data_path, "r") as f:
            X = f["X"][:]
            Y = f["Y"][:]
        idx, count = np.unique(Y, return_counts=True)
        return X, Y, len(idx), count

# Model

In [None]:
class img2vec(nn.Module):
    """Convolutional autoencoder that embeds a 3-channel image as ``vector_length`` channels.

    Encoder: thirteen 3x3 convolutions (padding 1, ReLU) in five stages of
    2, 2, 3, 3, 3 layers. Stages 1-4 end in a 2x2/stride-2 max pool; stage 5
    ends in a max pool whose kernel is ``last_pool_size``. Every pool returns
    its indices so the decoder can unpool. The last encoder conv outputs
    ``vector_length`` channels, so the final pooled map is the embedding.

    Decoder (run only while ``self.training is True``): mirrors the encoder,
    max-unpooling with the stored indices and pre-pool sizes, and ends with a
    conv to 3 channels followed by ReLU.
    NOTE(review): that final ReLU makes reconstructions non-negative; if the
    training images are standardised (data path ends in ``_std``) and contain
    negative values, those can never be matched -- confirm this is intended.
    """
    def __init__(self, img_size, training=True, vector_length=args.vector_length):
        """
        Args:
            img_size: input side length. It is halved (truncating) four times
                to get the last pool kernel, e.g. 200 -> 100 -> 50 -> 25 -> 12,
                which matches the map size after the four 2x2 pools.
            training: initial value of ``self.training``. This is nn.Module's
                train/eval flag, so later ``.train()`` / ``.eval()`` calls
                overwrite it and thereby switch the decoder on or off.
            vector_length: embedding size. The default is read from the global
                ``args`` once, when this class is defined.
        """
        super(img2vec, self).__init__()
        last_pool_size = img_size
        for i in range(4):
            last_pool_size = (int)(last_pool_size / 2)
        self.training = training
        self.vector_length = vector_length
        # encoder
        self.conv11 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv12 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.maxpool1 = nn.MaxPool2d(2, stride=2, return_indices=True)
        self.conv21 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv22 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(2, stride=2, return_indices=True)
        self.conv31 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv32 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv33 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.maxpool3 = nn.MaxPool2d(2, stride=2, return_indices=True)
        self.conv41 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.conv42 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv43 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.maxpool4 = nn.MaxPool2d(2, stride=2, return_indices=True)
        self.conv51 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv52 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        # last encoder conv emits the embedding channels
        self.conv53 = nn.Conv2d(512, vector_length, kernel_size=3, padding=1)
        # kernel == remaining map size, so this collapses it to 1x1 for
        # img_size x img_size inputs
        self.maxpool5 = nn.MaxPool2d(last_pool_size, return_indices=True)
        
        # decoder
        self.maxunpool5 = nn.MaxUnpool2d(last_pool_size)
        self.conv53d = nn.Conv2d(vector_length, 512, kernel_size=3, padding=1)
        self.conv52d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv51d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.maxunpool4 = nn.MaxUnpool2d(2, stride=2)
        self.conv43d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv42d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv41d = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.maxunpool3 = nn.MaxUnpool2d(2, stride=2)
        self.conv33d = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv32d = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv31d = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.maxunpool2 = nn.MaxUnpool2d(2, stride=2)
        self.conv22d = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv21d = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.maxunpool1 = nn.MaxUnpool2d(2, stride=2)
        self.conv12d = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv11d = nn.Conv2d(64, 3, kernel_size=3, padding=1)
    
    def forward(self, x):
        """Encode ``x`` and, in training mode, decode it back.

        Args:
            x: image batch with 3 channels in dim 1 (conv11 takes 3 input channels).

        Returns:
            If ``self.training is True``: a 3-channel reconstruction with the same
            spatial size as ``x`` (unpooling restores the recorded pre-pool sizes).
            Otherwise: the pooled embedding map with ``vector_length`` channels,
            ``(N, vector_length, 1, 1)`` for an ``img_size`` x ``img_size`` input.
        """
        # encoder
        x = F.relu(self.conv11(x))
        x = F.relu(self.conv12(x))
        # Pre-pool sizes are recorded so unpooling can restore odd sizes
        # exactly (e.g. 25 -> 12 -> 25), which the default output size cannot.
        size1b = x.size()
        x, indices1b = self.maxpool1(x)
        x = F.relu(self.conv21(x))
        x = F.relu(self.conv22(x))
        size2b = x.size()
        x, indices2b = self.maxpool2(x)
        x = F.relu(self.conv31(x))
        x = F.relu(self.conv32(x))
        x = F.relu(self.conv33(x))
        size3b = x.size()
        x, indices3b = self.maxpool3(x)
        x = F.relu(self.conv41(x))
        x = F.relu(self.conv42(x))
        x = F.relu(self.conv43(x))
        size4b = x.size()
        x, indices4b = self.maxpool4(x)
        x = F.relu(self.conv51(x))
        x = F.relu(self.conv52(x))
        x = F.relu(self.conv53(x))
        size5b = x.size()
        x, indices5b = self.maxpool5(x)
        
        # decoder (skipped in eval mode, so x is returned as the embedding)
        if self.training is True:
            x = self.maxunpool5(x, indices5b, output_size=size5b)
            x = F.relu(self.conv53d(x))
            x = F.relu(self.conv52d(x))
            x = F.relu(self.conv51d(x))
            x = self.maxunpool4(x, indices4b, output_size=size4b)
            x = F.relu(self.conv43d(x))
            x = F.relu(self.conv42d(x))
            x = F.relu(self.conv41d(x))
            x = self.maxunpool3(x, indices3b, output_size=size3b)
            x = F.relu(self.conv33d(x))
            x = F.relu(self.conv32d(x))
            x = F.relu(self.conv31d(x))
            x = self.maxunpool2(x, indices2b, output_size=size2b)
            x = F.relu(self.conv22d(x))
            x = F.relu(self.conv21d(x))
            x = self.maxunpool1(x, indices1b, output_size=size1b)
            x = F.relu(self.conv12d(x))
            x = F.relu(self.conv11d(x))
        return x
    
def weights_init(m):
    """Re-initialise a Conv2d in place: Xavier-normal weights, N(0, 1) bias.

    Meant for ``model.apply(weights_init)``; modules that are not Conv2d, and
    bias-free convolutions' missing bias, are left untouched.
    """
    if not isinstance(m, nn.Conv2d):
        return
    init.xavier_normal_(m.weight.data)
    if m.bias is None:
        return
    init.normal_(m.bias.data)

# Loss and Training Loop

In [None]:
# Mean-squared-error criterion used by the training loop. 'mean' replaces the
# deprecated 'elementwise_mean'; MSELoss holds no parameters or buffers, so it
# needs no .cuda() to operate on GPU tensors.
l2_loss = nn.MSELoss(reduction='mean')

In [None]:
def train_per_epoch(model, data_loader, optimizer, epoch, use_visdom, alpha=1e-3, lamda=1e-1):
    """Run one training epoch of the autoencoder and report the summed losses.

    Per batch the objective is
        alpha * MSE(model(image), image) + lamda * sum_p mean(p ** 2)
    where the parameter-norm term is skipped when ``lamda == 0``.

    Args:
        model: module whose forward returns a reconstruction shaped like its input.
        data_loader: iterable of ``(image, label)`` batches; labels are ignored.
        optimizer: optimizer over ``model``'s parameters.
        epoch: 0-based epoch index; printed as-is, plotted as ``epoch + 1``.
        use_visdom: if true, append the totals to the module-level visdom windows.
        alpha: weight of the reconstruction (match) term.
        lamda: weight of the parameter-norm term.

    Returns:
        ``(total_match_loss, total_norm_loss)`` summed over all batches.
    """
    model.train()
    # Move batches to wherever the model lives instead of hard-coding .cuda().
    device = next(model.parameters()).device
    total_match_loss = 0.0
    total_norm_loss = 0.0

    for image, label in tqdm(data_loader, leave=False):
        image = image.to(device)

        optimizer.zero_grad()
        output = model(image)

        match_loss = alpha * l2_loss(output, image)
        if lamda != 0:
            # mean(p ** 2) equals MSE(p, zeros) without allocating a zero
            # tensor for every parameter on every batch.
            norm_loss = lamda * sum(param.pow(2).mean() for param in model.parameters())
            # Out-of-place add: an in-place `+=` on an alias of match_loss
            # would fold the norm term into the logged match loss.
            loss = match_loss + norm_loss
            total_norm_loss += norm_loss.item()
        else:
            loss = match_loss

        loss.backward()
        optimizer.step()
        total_match_loss += match_loss.item()

    print('epoch ' + str(epoch) + ' || MATCH Loss: %.4f NORM Loss: %.4f ||' % (total_match_loss, total_norm_loss))

    if use_visdom:
        visualize(vis, epoch+1, total_match_loss, win_match_loss)
        visualize(vis, epoch+1, total_norm_loss, win_norm_loss)
        visualize(vis, epoch+1, total_match_loss + total_norm_loss, win_all_loss)
    return total_match_loss, total_norm_loss

### Save args

In [None]:
# Record every configuration value next to this run's checkpoints.
args_logger = Logger(file_root=args.model_save_path_root, filename="args.txt")
args_map = generate_args_map(args)
save_experiment_args(args_logger, args)

### Make data

In [None]:
# The whole HDF5 file is read into memory; batches are reshuffled every epoch.
train_dataset = insect_dataset(args.all_data_path)
train_data_loader = data.DataLoader(
    train_dataset,
    batch_size=args.bs,
    shuffle=True,
    num_workers=args.num_workers,
)

### Make model

In [None]:
# Seed torch's default RNG so weight initialisation is reproducible; DataLoader
# shuffling also draws from this default generator when no `generator` is given.
# (Set `seed` on `args` to override the default of 0.)
torch.manual_seed(getattr(args, "seed", 0))
# 200 is img2vec's img_size -- presumably the side of the square training images.
model = img2vec(200, training = True).cuda()
model.apply(weights_init)
print(model)
optimizer = optim.Adam(model.parameters(), lr=args.lr)

# Train

In [None]:
# Checkpoints are named by the number of completed epochs (epoch + 1), the same
# x-axis the visdom loss curves use.
os.makedirs(args.model_save_path_root, exist_ok=True)
for epoch in range(args.nepoch):
    train_per_epoch(model, train_data_loader, optimizer, epoch, args.visdom, alpha=args.alpha, lamda=args.lamda)

    # save model every `save_interval` completed epochs (0 disables periodic saves)
    completed_epochs = epoch + 1
    if args.save_interval > 0 and completed_epochs % args.save_interval == 0:
        print('Saving state, epoch: ' + str(completed_epochs))
        torch.save(model.state_dict(), pj(args.model_save_path_root, 'insect2vec_{}.pth'.format(str(completed_epochs))))

# save final model
print('Saving state, final')
torch.save(model.state_dict(), pj(args.model_save_path_root, 'insect2vec_final.pth'))

# Output semantic vector

In [None]:
def output_semantic_vector(model, data_loader, class_num, class_count, vector_length):
    """Average the L2-normalised embedding of every image, per class.

    Args:
        model: encoder whose eval-mode output reshapes to ``(batch, vector_length)``.
        data_loader: iterable of ``(image, label)`` batches; any batch size works.
        class_num: number of rows in the result.
        class_count: per-class sample counts; row ``i`` is divided by ``class_count[i]``.
        vector_length: embedding width.

    Returns:
        ``np.ndarray`` of shape ``(class_num, vector_length)``.

    NOTE(review): labels are used directly as row indices, so they are assumed
    to be integers in ``[0, class_num)`` aligned with ``class_count`` -- confirm.
    """
    model.eval()
    device = next(model.parameters()).device
    semantic_vector = np.zeros((class_num, vector_length))

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for image, label in tqdm(data_loader, leave=False):
            image = image.to(device)
            output_vector = model(image)
            # One row per sample, each normalised on its own, for any batch size.
            output_vector = output_vector.reshape(output_vector.shape[0], -1)
            output_vector = F.normalize(output_vector, p=2, dim=1)
            rows = np.asarray(label).reshape(-1)
            # np.add.at accumulates correctly when a batch repeats a label,
            # unlike fancy-index `+=`.
            np.add.at(semantic_vector, rows, output_vector.cpu().numpy())

    counts = np.asarray(class_count, dtype=np.float64)
    semantic_vector[:len(counts)] /= counts[:, None]
    return semantic_vector

def write_semantic_vector(semantic_vector, semantic_save_path_root, vector_length):
    """Write one space-separated line per class vector to ``<root>/vectors.txt``.

    Args:
        semantic_vector: 2-D array-like, one row per class.
        semantic_save_path_root: output directory, created if missing.
        vector_length: expected number of values in every row.

    Raises:
        ValueError: if a row does not contain exactly ``vector_length`` values,
            rather than writing a file with misplaced line breaks.
    """
    lines = []
    for row_index, vector in enumerate(semantic_vector):
        values = [str(num) for num in vector]
        if len(values) != vector_length:
            raise ValueError(
                "row %d has %d values, expected vector_length=%d"
                % (row_index, len(values), vector_length)
            )
        lines.append(" ".join(values) + "\n")
    if semantic_save_path_root:
        os.makedirs(semantic_save_path_root, exist_ok=True)
    with open(pj(semantic_save_path_root, "vectors.txt"), mode="w") as f:
        f.writelines(lines)

In [None]:
# Reload the full dataset for embedding extraction. A fixed order
# (shuffle=False) keeps the per-class float sums bit-for-bit reproducible;
# the averages do not otherwise depend on visit order.
test_dataset = insect_dataset(args.all_data_path)
class_num = test_dataset.class_num
class_count = test_dataset.class_count
test_data_loader = data.DataLoader(test_dataset, 1, num_workers=args.num_workers, shuffle=False)

In [None]:
final_checkpoint_path = pj(args.model_save_path_root, 'insect2vec_final.pth')
# training=False: forward skips the decoder and returns the pooled embedding.
model = img2vec(200, training=False).cuda()
model.load_state_dict(torch.load(final_checkpoint_path))

In [None]:
semantic_vector = output_semantic_vector(
    model,
    test_data_loader,
    class_num=class_num,
    class_count=class_count,
    vector_length=args.vector_length,
)

In [None]:
write_semantic_vector(
    semantic_vector,
    semantic_save_path_root=args.semantic_save_path_root,
    vector_length=args.vector_length,
)