In [2]:
from TripletLossFunc import TripletLossFunc
import random

import torch
import torch.nn.functional as F

import torchvision.models as models

from Utils import *
from Define import *
from NUS_WIDE_Helper import *

from jupyterplot import ProgressPlot
from tqdm.notebook import tqdm
from TenNetImage import *
from TenNetTag import *
from TagDecoder import *
import TenNetModel as TenNet
import TenDecoderModel as D

In [3]:
# Restore previously trained checkpoints instead of training from scratch.
getTenNetFromFile = False
getDecoderFromFile = False

# Seed the RNGs this notebook draws from (DataLoader shuffling and, presumably,
# weight initialisation inside the model constructors) so that
# Restart & Run All reproduces the same run. GPU kernels may still be
# non-deterministic.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# NUS-WIDE, 81-concept split: training set, plus a validation set drawn from the
# test split (second argument presumably caps it at Number_Of_Images_Valid images).
train_data = NUS_WIDE_Helper(DataSetType.Train_81)
valid_data = NUS_WIDE_Helper(DataSetType.Test_81, Number_Of_Images_Valid)
# test_data = NUS_WIDE_Helper(DataSetType.Test_81)

# Both loaders use identical settings.
batch_size = BATCH_SIZE
loader_kwargs = dict(shuffle=True, batch_size=batch_size)
train_loader = torch.utils.data.DataLoader(train_data, **loader_kwargs)
valid_loader = torch.utils.data.DataLoader(valid_data, **loader_kwargs)
# test_loader = torch.utils.data.DataLoader(test_data, **loader_kwargs)

In [None]:
# Image and tag encoders. The tag encoder is built from the training tag list.
# Both are optimised jointly by one Adam optimiser with one parameter group
# per encoder.
# Alternative image encoder (disabled): models.vgg16(pretrained=True).to(device)
image_model = TenNet_Image().to(device)
tag_model = TenNet_Tag(train_data.get_tag_list()).to(device)

optim = torch.optim.Adam(
    [{'params': module.parameters()} for module in (image_model, tag_model)],
    lr=0.001,
)
triplet_loss = TripletLossFunc(Margin_Distance)

# Training hyper-parameters and bookkeeping.
max_v = Margin_Distance * 10  # upper y-limit for the progress plot
n_epochs = N_Epochs
Lambda = 0.3                  # weight passed to TenNet.train (image-image term in the debug cell)
min_valid_loss = -1           # initial value handed to TenNet.evalue (presumably "none yet")
ten_res = []                  # one [train_stats, valid_stats] entry per epoch

In [None]:
# Stage 1: train the image/tag encoders, or restore them from "best_val.ckpt".
# NOTE(review): the decoder set-up cell rebinds `n_epochs` and `optim`. Re-run
# the model set-up cell before re-running this one.
if not getTenNetFromFile:
    pp = ProgressPlot(
        plot_names=["loss", "train distance"],
        line_names=["train/pos_IT", "valid/pos_II", "0/neg_IT", "0/neg_II"],
        x_lim=[0, n_epochs - 1],
        y_lim=[0, max_v],
    )

    for epoch in tqdm(range(n_epochs)):
        train_stats = TenNet.train(image_model, tag_model, train_loader, triplet_loss, Lambda, optim)
        TenNet.output_loss_dis(f"epoch:{epoch}: 1-train dataset with train model", train_stats)

        valid_stats = TenNet.evalue(
            image_model, tag_model, valid_loader, triplet_loss, Lambda, optim,
            epoch, min_valid_loss, True,
        )
        TenNet.output_loss_dis(f"epoch:{epoch}: 2-valid dataset with evalue model", valid_stats)
        # evalue receives the current reference loss. Its 6th statistic becomes
        # the next reference (presumably the best-so-far value — TODO confirm).
        min_valid_loss = valid_stats[5]

        ten_res.append([train_stats, valid_stats])
        TenNet.updataProgressPlot(pp, train_stats, valid_stats, max_v)

    pp.finalize()
else:
    TenNet.getTenModel(tag_model, image_model, name = "best_val.ckpt")

In [None]:
# Summarise the embedding-training history.
# The epoch count is taken from the recorded history itself, not the shared
# `n_epochs`. The decoder set-up cell rebinds `n_epochs` to N_Epochs_Decoder,
# which would give a wrong count if this cell were re-run afterwards.
# The history is empty when the encoders were restored from a checkpoint.
if ten_res:
    TenNet.printResult(ten_res, len(ten_res), max_v)
else:
    print("No TenNet training history to summarise.")

In [None]:
# Stage 2: a tag decoder sized to the training vocabulary.
# Only the decoder's parameters go to its optimiser.
# NOTE(review): this rebinds the notebook-wide `optim` and `n_epochs`, which the
# embedding training cell also reads.
predict_model = D.TagDecoder(train_data.get_tag_num()).to(device)
optim = torch.optim.Adam(predict_model.parameters(), lr=0.001)
loss_funk = F.pairwise_distance  # loss callable handed to D.train / D.predict

n_epochs = N_Epochs_Decoder
# Initial reference accuracy for D.predict (presumably "none yet") and the
# threshold passed to D.train / D.predict.
max_accuracy, threshold = -1, 0.5
decoder_res = []  # one [train_stats, valid_stats] entry per epoch

In [None]:
# Stage 2 training: fit the tag decoder, or restore it from "decoder_best_val.ckpt".
# The optimiser built in the set-up cell holds only the decoder's parameters.
if getDecoderFromFile:
    # Load into the decoder instance created in the set-up cell.
    # NOTE(review): getDecoderModel resolves through the star imports; confirm
    # its module (the TenNet equivalent is TenNet.getTenModel).
    getDecoderModel(predict_model, name = "decoder_best_val.ckpt")
else:
    pp = ProgressPlot(plot_names=["loss", "mean num of tags", "accuracy"],
                      line_names=["train/correct", "valid/total"],
                      x_lim=[0, n_epochs-1],
                      y_lim=[[0,50], [0,20], [0,1]])
    pbar = tqdm(range(n_epochs))

    for e in pbar:
        loss_num_train = D.train(predict_model, tag_model, image_model, train_loader, loss_funk, optim, threshold)
        D.output_loss_num(f"epoch:{e}: 1-train dataset with train model", loss_num_train)

        loss_num_valid = D.predict(predict_model, tag_model, image_model, valid_loader, loss_funk, optim, threshold, True, max_accuracy, e)
        D.output_loss_num(f"epoch:{e}: 2-valid dataset with evalue model", loss_num_valid)
        # predict receives the current reference accuracy. Its 5th statistic
        # becomes the next reference (presumably the best-so-far value — TODO confirm).
        max_accuracy = loss_num_valid[4]

        decoder_res.append([loss_num_train, loss_num_valid])

        # NOTE(review): unlike the embedding loop (TenNet.updataProgressPlot),
        # this name resolves through the star imports; confirm it is the
        # decoder's three-argument variant.
        updataProgressPlot(pp, loss_num_train, loss_num_valid)

    pp.finalize()

In [None]:
# Summarise the decoder-training history.
# The epoch count comes from the recorded history rather than the shared
# `n_epochs`, which other cells rebind.
# The history is empty when the decoder was restored from a checkpoint.
if decoder_res:
    D.printResult(decoder_res, len(decoder_res))
else:
    print("No decoder training history to summarise.")

In [None]:
# Debug: push one training batch through both encoders and rebuild the two
# triplet losses by hand, so the selected positives/negatives can be inspected
# (the next cell prints some of them).
# Nothing here calls backward(), so torch.no_grad() avoids keeping this batch's
# autograd graph alive in the kernel.
# NOTE(review): the encoders run in whatever train/eval mode earlier cells left them in.
loader = train_loader
x_images, y_tags = next(iter(loader))
x_images, y_tags = x_images.to(device), y_tags.to(device)

with torch.no_grad():
    image_features = image_model(x_images)
    tag_features = tag_model(y_tags)

    # Row-wise distance between each image and its own tags in feature space.
    IT_dist = F.pairwise_distance(image_features, tag_features)

    # First triplet: anchor image, its own tags, a negative tag chosen by get_one_neighbor.
    anchor_image = image_features
    positive_tag = tag_features

    similarity_matrix = get_similarity_matrix(y_tags)
    z_tag_indexes = get_one_neighbor(tag_features, similarity_matrix, IT_dist + Margin_Distance)
    negative_tag = torch.cat([tag_features[i].view(1, -1) for i in z_tag_indexes])

    z_images_pos, z_images_neg = get_pos_neg(y_tags, similarity_matrix)
    positive_image = torch.cat([image_features[i].view(1, -1) for i in z_images_pos])
    negative_image = torch.cat([image_features[i].view(1, -1) for i in z_images_neg])

    lossIT, dist_image_tag_pos, dist_image_tag_neg = triplet_loss(anchor_image, positive_tag, negative_tag)

    # Second triplet: anchor image, a positive image, a negative image (from get_pos_neg).
    lossII, dist_image_image_pos, dist_image_image_neg = triplet_loss(anchor_image, positive_image, negative_image)
    loss = lossIT + Lambda * lossII

In [None]:
# Show the positive/negative image indices, then the first tag embedding and
# all image embeddings, from the debug batch above.
print(z_images_pos, z_images_neg)
for shown in (tag_features[0], image_features):
    print(shown)