In [1]:
import torch
import torch.nn.functional as F
import torchvision.models as models
from jupyterplot import ProgressPlot
from tqdm.notebook import tqdm

from Define import *
import NUS_WIDE as nus
import TenNet as TenNet
import PredictionModel as P
from TripletLossFunc import TripletLossFunc

In [2]:
# --- Run configuration -------------------------------------------------------
# test: quick smoke-run switch. Later cells read it to shrink the training set,
# swap in the smaller VGG11 backbone and stop TenNet training after epoch index 2.
test = False
# getDecoderFromFile: when True, the decoder cell loads
# "SavedModelState/decoder_model.ckpt" instead of training the decoder.
getDecoderFromFile = False
# getTenNetFromFile: when True, the TenNet cell loads
# "SavedModelState/IT_model.ckpt" instead of training the image/tag models.
getTenNetFromFile = False

In [3]:
# Smoke-test mode: cap the training set at 400 images, rounded down to a whole
# multiple of BATCH_SIZE. This overrides the value that came from `Define`.
if test:
    Number_Of_Images_Train = int(400 / BATCH_SIZE) * BATCH_SIZE

batch_size = BATCH_SIZE

# NUS_WIDE datasets: the Train_81 split for training, the Test_81 split for validation.
train_data = nus.NUS_WIDE_Helper(nus.DataSetType.Train_81, Number_Of_Images_Train)
valid_data = nus.NUS_WIDE_Helper(nus.DataSetType.Test_81, Number_Of_Images_Valid)
#test_data = NUS_WIDE_Helper(DataSetType.Test_81)

# Both loaders draw shuffled mini-batches of `batch_size` samples.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, shuffle=True)
#test_loader = torch.utils.data.DataLoader(test_data, shuffle=True, batch_size=batch_size)

In [4]:
# Image branch: a pretrained torchvision VGG. The smaller VGG11 is used for
# smoke tests, VGG16 for real runs.
# (Disabled alternative: image_model = TenNet_Image().to(device))
image_model = (models.vgg11 if test else models.vgg16)(pretrained=True).to(device)

# Tag branch, built from the tag list of the training set.
tag_model = TenNet.TenNet_Tag(train_data.get_tag_list()).to(device)

# One Adam optimizer with a separate parameter group per network, sharing one
# learning rate.
optim = torch.optim.Adam(
    [{'params': net.parameters()} for net in (image_model, tag_model)],
    lr=1e-4,
)
triplet_loss = TripletLossFunc(Margin_Distance)

# Training-loop settings read by the TenNet training cell.
n_epochs = 20
Lambda = 0.01          # forwarded unchanged to TenNet.train / TenNet.evalue
min_valid_loss = -1    # presumably a "no best loss yet" sentinel - TODO confirm in TenNet.evalue
ten_res = list()       # one [train_result, valid_result] pair is appended per epoch

In [5]:
if getTenNetFromFile:
    # Reuse a previously saved image/tag model pair instead of training.
    TenNet.getTenModel(tag_model, image_model, name="SavedModelState/IT_model.ckpt")
else:
    # Train the image and tag networks for `n_epochs` epochs. In each epoch the
    # validation pass runs BEFORE the training pass, so the validation numbers
    # describe the weights as they were at the start of that epoch.
    pbar = tqdm(range(n_epochs))

    for e in pbar:

        print(f"epoch:{e}:")
        print("1-7 use the same weights, since just 1 updates the weights.\n")

        # Validation pass. The current epoch index, the running
        # `min_valid_loss` and a trailing True flag are handed to TenNet.evalue
        # (presumably to decide whether to checkpoint - TODO confirm there).
        loss_dis_valid = TenNet.evalue(image_model, tag_model, valid_loader, triplet_loss, Lambda, optim, e, min_valid_loss, True)
        TenNet.output_loss_dis(" 2-valid dataset with evalue model", loss_dis_valid)
        # Element 5 of the result is carried into the next epoch's evalue call.
        min_valid_loss = loss_dis_valid[5]

        # Diagnostic passes 3-6 (TenNet.evalue on the training loader,
        # TenNet.evalue2 / TenNet.evalue3, and TenNet.train called with a
        # trailing False argument) used to be parked here inside a string
        # literal and were never executed. They have been removed; the
        # numbered messages printed around this point still refer to them.

        # Training pass.
        loss_dis_train = TenNet.train(image_model, tag_model, train_loader, triplet_loss, Lambda, optim)
        TenNet.output_loss_dis(" 1-train dataset with train model", loss_dis_train)

        print("6-6-1 use the same data set and way to train the model, but 6 doesn't updata the weight.\n")

        # History consumed by the TenNet result-plotting cell.
        ten_res.append([loss_dis_train, loss_dis_valid])

        # Smoke-test mode stops after three epochs (indices 0, 1, 2).
        if test and e == 2:
            break
        

  0%|          | 0/20 [00:00<?, ?it/s]

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


epoch:0: 1-train dataset with train model
loss: 157.14,  IT_pos_dis: 38.27,  II_pos_dis: 23.87,  IT_neg_dis: 38.60,  II_neg_dis: 23.43
 
epoch:0: 2-valid dataset with evalue model
loss: 65.12,  IT_pos_dis: 10.55,  II_pos_dis: 1.14,  IT_neg_dis: 10.55,  II_neg_dis: 1.14
 
epoch:1: 1-train dataset with train model
loss: 94.64,  IT_pos_dis: 22.63,  II_pos_dis: 9.60,  IT_neg_dis: 22.98,  II_neg_dis: 9.72
 


KeyboardInterrupt: 

In [None]:
# Report the TenNet training history collected in `ten_res`: a textual loss
# log, the loss curves, then the distance curves for the training results and
# for the validation results.
res = ten_res
TenNet.printLossLog(res, n_epochs)
TenNet.printLossProgressPlot(res, n_epochs)
for on_train_split in (True, False):
    TenNet.printDistanceProgressPlot(res, n_epochs, train=on_train_split)

In [None]:
# Decoder-stage settings.
# NOTE: `n_epochs` and `optim` are rebound here; the TenNet cells earlier in
# the notebook use the same two names, so re-running those cells after this one
# picks up the decoder's values.
n_epochs = N_Epochs_Decoder
threshold = 0.5        # forwarded unchanged to P.train / P.predict
max_accuracy = -1      # presumably a "no best accuracy yet" sentinel - TODO confirm in P.predict
decoder_res = list()   # one [train_result, valid_result] pair is appended per epoch

# Tag decoder sized by the number of tags in the training set.
predict_model = P.TagDecoder(train_data.get_tag_num()).to(device)

# Adam over the decoder parameters only; pairwise distance is the loss function.
optim = torch.optim.Adam(predict_model.parameters(), lr=1e-3)
loss_funk = F.pairwise_distance

In [None]:
if not getDecoderFromFile:
    # Train the tag decoder: one training pass and then one validation pass
    # per epoch, logging both.
    pbar = tqdm(range(n_epochs))

    for e in pbar:
        epoch_label = f"epoch:{e}:"

        loss_num_train = P.train(
            predict_model, tag_model, image_model,
            train_loader, loss_funk, optim, threshold,
        )
        P.output_loss_num(f"{epoch_label} 1-train dataset with train model", loss_num_train)

        # The trailing True flag, the running `max_accuracy` and the epoch
        # index are handed to P.predict (presumably for checkpointing - TODO
        # confirm there).
        loss_num_valid = P.predict(
            predict_model, tag_model, image_model,
            valid_loader, loss_funk, optim, threshold,
            True, max_accuracy, e,
        )
        P.output_loss_num(f"{epoch_label} 2-valid dataset with evalue model", loss_num_valid)
        # Element 4 of the result is carried into the next epoch's predict call.
        max_accuracy = loss_num_valid[4]

        # History consumed by the decoder result-plotting cell.
        decoder_res.append([loss_num_train, loss_num_valid])
else:
    # Reuse a previously saved decoder instead of training.
    P.getDecoderModel(predict_model, name="SavedModelState/decoder_model.ckpt")

In [None]:

    pp = ProgressPlot(plot_names=["loss", "mean num of tags", "accuracy"], 
                      line_names=["train/correct", "valid/total"],
                      x_lim=[0, n_epochs-1], 
                      y_lim=[[0,50], [0,20], [0,1]])
P.printResult(decoder_res, n_epochs)

In [6]:
# Debug dump of intermediate values.
# NOTE(review): z_images_pos, z_images_neg, tag_features and image_features are
# not assigned in any cell above. Unless `from Define import *` happens to
# provide them, this cell raises NameError on a fresh kernel (Restart & Run
# All). Presumably they leaked from a deleted cell or a debugging session -
# confirm, then either define them explicitly or delete this cell.
print(z_images_pos, z_images_neg)
print(tag_features[0])
print(image_features)

[19, 21, 21, 17, 18, 19, 25, 29, 20, 20, 9, 27, 29, 22, 17, 20, 15, 26, 26, 18, 19, 20, 26, 22, 23, 26, 25, 29, 27, 28, 29, 30] [8, 2, 16, 21, 29, 22, 12, 29, 7, 30, 17, 25, 11, 27, 13, 29, 15, 29, 17, 24, 19, 29, 29, 29, 23, 24, 27, 29, 27, 28, 29, 30]
tensor([ 4.6643e-01,  1.3243e-01,  1.9441e-01, -4.7355e-01, -1.4267e-01,
        -3.4867e-01,  2.6157e-02, -2.7753e-01,  5.1259e-02,  2.3448e-01,
         8.6683e-02,  6.0205e-02, -7.1069e-01, -2.8344e-01, -3.3232e-01,
         1.5805e-01,  2.1986e-01,  1.7727e-01, -1.7434e-01,  6.6996e-01,
        -4.2656e-01,  1.1245e-01, -3.2373e-01, -4.6245e-01, -3.1490e-01,
        -3.7352e-02, -2.8791e-01, -1.7289e-01,  6.5100e-01, -2.4598e-01,
        -2.4802e-01, -7.0213e-01,  2.2615e-01, -5.0316e-01, -7.3564e-01,
        -3.9543e-01, -2.4660e-01, -1.8978e-01,  3.4221e-01, -3.5239e-01,
         1.0371e-01,  2.8671e-01, -2.2580e-02, -1.6176e+00,  1.3838e-01,
        -5.9134e-01,  1.1898e-01,  1.3564e-02, -3.3850e-01, -3.5049e-01,
         2.5211e