In [1]:
import copy
import csv
import os
import random
from collections import defaultdict

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
from tqdm import tqdm

In [2]:
from utils import Data_Read
from Config import Config
from Model import Intere_Gene_Model
%load_ext autoreload
%autoreload 2
config = Config()

In [3]:
train_list_np, test_x, test_y = Data_Read(config)

# One shuffled DataLoader per element of `train_list_np`; every batch is a
# 1-tuple holding the tensor built from that element.
train_loader_list = [
    DataLoader(
        dataset=TensorDataset(torch.tensor(chunk)),
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=0,
    )
    for chunk in train_list_np
]

# Test batches pair inputs with their targets; order is kept fixed.
test_set = TensorDataset(torch.tensor(test_x), torch.tensor(test_y))
test_loader = DataLoader(
    dataset=test_set,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=0,
)

model = Intere_Gene_Model(config).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

In [None]:
print(config.model_type)

# Scale applied to the summed training loss before logging; kept from the
# original run so printed numbers stay comparable with earlier logs.
# NOTE(review): in the recorded run the logged loss keeps falling without bound
# (about -1.3e9 by epoch 23) while the metrics plateau. Check whether
# Intere_Gene_Model's objective is bounded below.
LOSS_LOG_SCALE = 1e-6

# Per-N history of the test metrics, one entry per epoch (written to CSV later).
Recall_list = defaultdict(list)
HitRate_list = defaultdict(list)


def train_one_epoch(model, loaders, optimizer):
    """Run one optimisation pass over every DataLoader in ``loaders``.

    Each batch is the 1-tuple yielded by a ``TensorDataset``. Its tensor is
    moved to the GPU and passed to ``model``, whose return value is used
    directly as the loss.

    Returns:
        float: sum of the per-batch losses, multiplied by ``LOSS_LOG_SCALE``.
    """
    model.train()
    total_loss = 0.0
    for loader in tqdm(loaders):
        for (batch,) in loader:
            batch = batch.cuda()
            optimizer.zero_grad()
            loss = model(batch)
            loss.backward()
            optimizer.step()
            # .item() yields a plain float. Accumulating the tensor itself would
            # chain every batch's autograd graph into `total_loss` for the
            # whole epoch.
            total_loss += loss.item() * LOSS_LOG_SCALE
    return total_loss


def evaluate(model, loader, n_list):
    """Compute mean Recall@N and HitRate@N over every test row in ``loader``.

    For a row with targets ``y_i`` and top-N recommendations ``r_i``:
    Recall@N  = |r_i ∩ y_i| / len(y_i), and
    HitRate@N = 1 if r_i ∩ y_i is non-empty, else 0.
    Both are averaged over all rows.

    Returns:
        (recall, hit_rate): two ``defaultdict(float)`` keyed by N.

    Raises:
        ValueError: if ``loader`` yields no rows (the averages would divide by 0).
    """
    model.eval()
    test_num = 0
    recall = defaultdict(float)
    hit_rate = defaultdict(float)
    # Inference only: no autograd graph is needed while serving.
    with torch.no_grad():
        for x, y in tqdm(loader):
            x = x.cuda()
            y = y.numpy()
            # Indexed as [row, rank] and fed to np.intersect1d below, so it is
            # presumably array-like of item ids. TODO confirm.
            recall_topN = model.serving(x)

            test_num += x.shape[0]
            for N in n_list:
                this_N = recall_topN[:, :N]
                for i in range(y.shape[0]):
                    n_hits = np.intersect1d(this_N[i], y[i]).shape[0]
                    recall[N] += n_hits / y[i].shape[0]
                    hit_rate[N] += n_hits > 0

    if test_num == 0:
        raise ValueError("test loader yielded no samples; cannot average metrics")
    for N in n_list:
        recall[N] /= test_num
        hit_rate[N] /= test_num
    return recall, hit_rate


for epoch in range(config.epochs):
    # ---------------------------- training ----------------------------
    total_loss = train_one_epoch(model, train_loader_list, optimizer)
    print("epoch: %d   loss: %10.4f"%(epoch+1,total_loss))

    # ---------------------------- testing -----------------------------
    Recall, HitRate = evaluate(model, test_loader, config.N_list)
    for N in config.N_list:
        print("Recall%d :  %10.4f    HitRate%d :  %10.4f"%(N,Recall[N],N,HitRate[N]))
        # Metric values are plain floats, so no deepcopy is needed.
        Recall_list[N].append(Recall[N])
        HitRate_list[N].append(HitRate[N])

Intere_Gene v1


100%|██████████| 20/20 [01:07<00:00,  3.37s/it]


epoch: 1   loss:   -21.3167


100%|██████████| 61/61 [00:15<00:00,  4.06it/s]


Recall20 :      0.0017    HitRate20 :      0.0073
Recall50 :      0.0026    HitRate50 :      0.0119


100%|██████████| 20/20 [01:07<00:00,  3.35s/it]


epoch: 2   loss: -13751.4648


100%|██████████| 61/61 [00:15<00:00,  3.91it/s]


Recall20 :      0.0043    HitRate20 :      0.0171
Recall50 :      0.0058    HitRate50 :      0.0237


100%|██████████| 20/20 [01:08<00:00,  3.40s/it]


epoch: 3   loss: -166150.6406


100%|██████████| 61/61 [00:15<00:00,  4.02it/s]


Recall20 :      0.0090    HitRate20 :      0.0350
Recall50 :      0.0123    HitRate50 :      0.0480


100%|██████████| 20/20 [01:07<00:00,  3.39s/it]


epoch: 4   loss: -772874.5000


100%|██████████| 61/61 [00:14<00:00,  4.24it/s]


Recall20 :      0.0119    HitRate20 :      0.0471
Recall50 :      0.0174    HitRate50 :      0.0676


100%|██████████| 20/20 [01:06<00:00,  3.35s/it]


epoch: 5   loss: -2417177.0000


100%|██████████| 61/61 [00:15<00:00,  3.98it/s]


Recall20 :      0.0122    HitRate20 :      0.0472
Recall50 :      0.0204    HitRate50 :      0.0773


100%|██████████| 20/20 [01:07<00:00,  3.40s/it]


epoch: 6   loss: -5784604.0000


100%|██████████| 61/61 [00:15<00:00,  3.83it/s]


Recall20 :      0.0127    HitRate20 :      0.0477
Recall50 :      0.0213    HitRate50 :      0.0805


100%|██████████| 20/20 [01:11<00:00,  3.56s/it]


epoch: 7   loss: -11583180.0000


100%|██████████| 61/61 [00:16<00:00,  3.65it/s]


Recall20 :      0.0121    HitRate20 :      0.0449
Recall50 :      0.0210    HitRate50 :      0.0790


100%|██████████| 20/20 [01:10<00:00,  3.51s/it]


epoch: 8   loss: -20470730.0000


100%|██████████| 61/61 [00:16<00:00,  3.70it/s]


Recall20 :      0.0127    HitRate20 :      0.0476
Recall50 :      0.0219    HitRate50 :      0.0823


100%|██████████| 20/20 [01:04<00:00,  3.22s/it]


epoch: 9   loss: -33491968.0000


100%|██████████| 61/61 [00:16<00:00,  3.81it/s]


Recall20 :      0.0122    HitRate20 :      0.0444
Recall50 :      0.0220    HitRate50 :      0.0800


100%|██████████| 20/20 [01:07<00:00,  3.37s/it]


epoch: 10   loss: -51680616.0000


100%|██████████| 61/61 [00:16<00:00,  3.80it/s]


Recall20 :      0.0125    HitRate20 :      0.0453
Recall50 :      0.0221    HitRate50 :      0.0786


100%|██████████| 20/20 [01:10<00:00,  3.52s/it]


epoch: 11   loss: -75494768.0000


100%|██████████| 61/61 [00:17<00:00,  3.49it/s]


Recall20 :      0.0128    HitRate20 :      0.0464
Recall50 :      0.0227    HitRate50 :      0.0809


100%|██████████| 20/20 [01:10<00:00,  3.51s/it]


epoch: 12   loss: -106482024.0000


100%|██████████| 61/61 [00:15<00:00,  4.04it/s]


Recall20 :      0.0129    HitRate20 :      0.0467
Recall50 :      0.0221    HitRate50 :      0.0790


100%|██████████| 20/20 [01:08<00:00,  3.42s/it]


epoch: 13   loss: -144621056.0000


100%|██████████| 61/61 [00:15<00:00,  3.94it/s]


Recall20 :      0.0136    HitRate20 :      0.0510
Recall50 :      0.0231    HitRate50 :      0.0823


100%|██████████| 20/20 [01:10<00:00,  3.52s/it]


epoch: 14   loss: -191776416.0000


100%|██████████| 61/61 [00:16<00:00,  3.78it/s]


Recall20 :      0.0138    HitRate20 :      0.0513
Recall50 :      0.0236    HitRate50 :      0.0838


100%|██████████| 20/20 [01:10<00:00,  3.54s/it]


epoch: 15   loss: -249551696.0000


100%|██████████| 61/61 [00:16<00:00,  3.70it/s]


Recall20 :      0.0139    HitRate20 :      0.0512
Recall50 :      0.0235    HitRate50 :      0.0823


100%|██████████| 20/20 [01:09<00:00,  3.45s/it]


epoch: 16   loss: -317349536.0000


100%|██████████| 61/61 [00:16<00:00,  3.69it/s]


Recall20 :      0.0138    HitRate20 :      0.0498
Recall50 :      0.0239    HitRate50 :      0.0843


100%|██████████| 20/20 [01:08<00:00,  3.41s/it]


epoch: 17   loss: -396848992.0000


100%|██████████| 61/61 [00:16<00:00,  3.69it/s]


Recall20 :      0.0139    HitRate20 :      0.0501
Recall50 :      0.0232    HitRate50 :      0.0827


100%|██████████| 20/20 [01:10<00:00,  3.53s/it]


epoch: 18   loss: -496329184.0000


100%|██████████| 61/61 [00:16<00:00,  3.70it/s]


Recall20 :      0.0140    HitRate20 :      0.0510
Recall50 :      0.0239    HitRate50 :      0.0841


100%|██████████| 20/20 [01:11<00:00,  3.56s/it]


epoch: 19   loss: -612641600.0000


100%|██████████| 61/61 [00:16<00:00,  3.79it/s]


Recall20 :      0.0142    HitRate20 :      0.0513
Recall50 :      0.0235    HitRate50 :      0.0826


100%|██████████| 20/20 [01:11<00:00,  3.58s/it]


epoch: 20   loss: -748626752.0000


100%|██████████| 61/61 [00:15<00:00,  3.81it/s]


Recall20 :      0.0146    HitRate20 :      0.0522
Recall50 :      0.0242    HitRate50 :      0.0854


100%|██████████| 20/20 [01:09<00:00,  3.49s/it]


epoch: 21   loss: -909762624.0000


100%|██████████| 61/61 [00:16<00:00,  3.75it/s]


Recall20 :      0.0142    HitRate20 :      0.0512
Recall50 :      0.0238    HitRate50 :      0.0838


100%|██████████| 20/20 [01:08<00:00,  3.42s/it]


epoch: 22   loss: -1098870016.0000


100%|██████████| 61/61 [00:16<00:00,  3.77it/s]


Recall20 :      0.0142    HitRate20 :      0.0514
Recall50 :      0.0237    HitRate50 :      0.0843


100%|██████████| 20/20 [01:09<00:00,  3.47s/it]


epoch: 23   loss: -1316180608.0000


 69%|██████▉   | 42/61 [00:11<00:05,  3.62it/s]

In [None]:
# NOTE(review): `model_save_path` and `model_name` were never defined in this
# notebook, so this cell raised NameError on a fresh kernel. Read them from
# `config` when it provides them; the fallbacks are placeholders. TODO confirm.
model_save_path = getattr(config, "model_save_path", "./saved_models")
model_name = getattr(config, "model_name", str(config.model_type).replace(" ", "_"))
os.makedirs(model_save_path, exist_ok=True)

# Whole pickled module (reloading it needs the Model source to be importable).
torch.save(model, os.path.join(model_save_path, model_name + ".pth"))
# Parameter-only checkpoint (load with model.load_state_dict).
torch.save(model.state_dict(), os.path.join(model_save_path, model_name + "_para.pth"))

In [None]:
# NOTE(review): `ans_save_path` and `model_name` were never defined in this
# notebook, so this cell raised NameError on a fresh kernel. Read them from
# `config` when it provides them; the fallbacks are placeholders. TODO confirm.
ans_save_path = getattr(config, "ans_save_path", "./results")
model_name = getattr(config, "model_name", str(config.model_type).replace(" ", "_"))
os.makedirs(ans_save_path, exist_ok=True)

# Row layout (no header): for each N in config.N_list, one row of per-epoch
# Recall@N followed by one row of per-epoch HitRate@N.
# newline="" is required by the csv module; without it Windows gets blank rows.
with open(os.path.join(ans_save_path, model_name + ".csv"), "w", encoding="utf8", newline="") as f:
    writer = csv.writer(f)
    for N in config.N_list:
        writer.writerow(Recall_list[N])
        writer.writerow(HitRate_list[N])

In [1]:
import random

In [2]:
# Scratch list for trying out random.sample; not used by the training cells.
l = [2, 6, 3, 5, 3, 9, 0, 3]

In [4]:
# Draw 3 elements from distinct positions of `l` (no replacement). Python's
# `random` is unseeded here, so the displayed result changes between runs.
random.sample(l, k=3)

[6, 2, 5]