In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

In [2]:
X = np.loadtxt('../data/R300/D14_X.txt')
Y = np.loadtxt('../data/R300/D14_label.txt')
X1 = np.loadtxt('../data/R300/designed_feature/D14_feature.txt')

n = len(Y)
tn = int(n*0.7)
idx = np.random.permutation(n)

In [3]:
len(X1[0])

4906

In [4]:
trainset = torch.utils.data.TensorDataset(torch.from_numpy(X1[idx[0:tn],:]), torch.from_numpy(Y[idx[0:tn]].reshape(-1,1)))
testset = torch.utils.data.TensorDataset(torch.from_numpy(X1[idx[tn:],:]), torch.from_numpy(Y[idx[tn:]].reshape(-1,1)))

In [5]:
class LiNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(4906,2048)
        self.linear2 = nn.Linear(2048,1024)
        self.linear3 = nn.Linear(1024,512)
        self.linear4 = nn.Linear(512,256)
        self.linear = nn.Linear(256,1)
        
    def forward(self, query, x):
        output = F.relu(self.linear1(x))
        output = F.relu(self.linear2(output))
        output = F.relu(self.linear3(output))
        output = F.relu(self.linear4(output))
        return self.linear(output)

In [6]:
batch_size = 128
trainloader = torch.utils.data.DataLoader(trainset, batch_size, drop_last=True, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size, drop_last=True)
query = torch.ones(batch_size, 1, dtype=torch.float64)

In [7]:
net = LiNet()
opt = torch.optim.Adam(net.parameters(), lr=0.0001, weight_decay=0.01, amsgrad=True)
mse = torch.nn.MSELoss()

In [8]:
for epoch in range(50):  # loop over the dataset multiple times
    print('epoch: ', epoch)
    if epoch % 30 == 29:
        for param_group in opt.param_groups:
            param_group['lr'] /= 2
            
    for i, data in tqdm(enumerate(trainloader, 0)):
        inputs, ys = data

        opt.zero_grad()
        outputs = net(query.float(), inputs.float())
        loss = mse(outputs.float(), ys.float())
        loss.backward()
        opt.step()  
        
    with torch.no_grad():
        net.eval()
        train_loss = 0.0
        batch_num = 0
        for data in trainloader:
            inputs, ys = data
            outputs = net(query.float(), inputs.float())
            batch_num += 1
            train_loss += mse(outputs.float(), ys.float()).item()
        train_loss /= batch_num
        print('Train Loss: ', train_loss)
        
        test_loss = 0.0
        batch_num = 0
        for data in testloader:
            inputs, ys = data
            outputs = net(query.float(), inputs.float())
            batch_num += 1
            test_loss += mse(outputs.float(), ys.float()).item()
        test_loss /= batch_num
        print('Test Loss: ', test_loss)
        net.train()
        

0it [00:00, ?it/s]

epoch:  0


44it [00:07,  5.65it/s]


Train Loss:  1.2803505794568495


1it [00:00,  5.75it/s]

Test Loss:  1.339800615059702
epoch:  1


44it [00:07,  5.75it/s]


Train Loss:  1.1783299852501263


1it [00:00,  5.40it/s]

Test Loss:  1.232278861497578
epoch:  2


44it [00:07,  5.76it/s]


Train Loss:  0.9156434549526735


1it [00:00,  5.82it/s]

Test Loss:  0.9546947040055928
epoch:  3


44it [00:07,  5.80it/s]


Train Loss:  0.6751220971345901


1it [00:00,  5.71it/s]

Test Loss:  0.7096235030575803
epoch:  4


44it [00:07,  5.78it/s]


Train Loss:  0.5833873477849093


1it [00:00,  5.75it/s]

Test Loss:  0.619315250923759
epoch:  5


44it [00:07,  5.78it/s]


Train Loss:  0.6372309997677803


1it [00:00,  5.54it/s]

Test Loss:  0.6940104020269293
epoch:  6


44it [00:07,  5.78it/s]


Train Loss:  0.5379894545132463


1it [00:00,  5.16it/s]

Test Loss:  0.5719498819426486
epoch:  7


44it [00:07,  5.73it/s]


Train Loss:  0.4974061447111043


1it [00:00,  5.62it/s]

Test Loss:  0.5492387774743532
epoch:  8


44it [00:07,  5.72it/s]


Train Loss:  0.4642749936743216


1it [00:00,  5.71it/s]

Test Loss:  0.5107846526723159
epoch:  9


44it [00:07,  5.71it/s]


Train Loss:  0.49617264487526636


1it [00:00,  5.78it/s]

Test Loss:  0.5362547557604941
epoch:  10


44it [00:07,  5.77it/s]


Train Loss:  0.4514721632003784


1it [00:00,  5.49it/s]

Test Loss:  0.5048304498195648
epoch:  11


44it [00:07,  5.72it/s]


Train Loss:  0.48173596235838806


1it [00:00,  5.71it/s]

Test Loss:  0.5198956288789448
epoch:  12


44it [00:07,  5.70it/s]


Train Loss:  0.4103998040611094


1it [00:00,  5.73it/s]

Test Loss:  0.4639220410271695
epoch:  13


44it [00:07,  5.68it/s]


Train Loss:  0.4476712881164117


1it [00:00,  5.71it/s]

Test Loss:  0.49081999534054804
epoch:  14


44it [00:07,  5.67it/s]


Train Loss:  0.4010425094853748


1it [00:00,  5.64it/s]

Test Loss:  0.4531414100998326
epoch:  15


44it [00:07,  5.65it/s]


Train Loss:  0.38686026294123044


1it [00:00,  5.48it/s]

Test Loss:  0.44135270777501556
epoch:  16


44it [00:07,  5.63it/s]


Train Loss:  0.3748311000791463


1it [00:00,  5.70it/s]

Test Loss:  0.4347251528187802
epoch:  17


44it [00:07,  5.67it/s]


Train Loss:  0.43298638747497037


1it [00:00,  5.56it/s]

Test Loss:  0.4802869244625694
epoch:  18


44it [00:07,  5.64it/s]


Train Loss:  0.3747642748057842


1it [00:00,  5.67it/s]

Test Loss:  0.43987832414476497
epoch:  19


44it [00:07,  5.67it/s]


Train Loss:  0.42193989726630127


1it [00:00,  5.54it/s]

Test Loss:  0.4985368565509194
epoch:  20


44it [00:07,  5.63it/s]


Train Loss:  0.37767860428853467


1it [00:00,  5.45it/s]

Test Loss:  0.43149080245118393
epoch:  21


44it [00:07,  5.66it/s]


Train Loss:  0.38074695793065155


1it [00:00,  5.49it/s]

Test Loss:  0.45256741266501577
epoch:  22


44it [00:07,  5.67it/s]


Train Loss:  0.38203827901320025


1it [00:00,  5.70it/s]

Test Loss:  0.4350870879072892
epoch:  23


44it [00:07,  5.69it/s]


Train Loss:  0.38371540335091675


1it [00:00,  5.63it/s]

Test Loss:  0.4360551457656057
epoch:  24


44it [00:07,  5.67it/s]


Train Loss:  0.46769597178155725


1it [00:00,  5.71it/s]

Test Loss:  0.5556029347996962
epoch:  25


44it [00:07,  5.65it/s]


Train Loss:  0.3662148327990012


1it [00:00,  5.75it/s]

Test Loss:  0.43958752092562225
epoch:  26


44it [00:07,  5.67it/s]


Train Loss:  0.3569035665555434


1it [00:00,  5.69it/s]

Test Loss:  0.4194726896913428
epoch:  27


44it [00:07,  5.70it/s]


Train Loss:  0.38871327990835364


1it [00:00,  5.75it/s]

Test Loss:  0.4682047241612485
epoch:  28


44it [00:07,  5.66it/s]


Train Loss:  0.3724685223265128


1it [00:00,  5.69it/s]

Test Loss:  0.42879289074947957
epoch:  29


44it [00:07,  5.54it/s]


Train Loss:  0.3510004132986069


1it [00:00,  5.56it/s]

Test Loss:  0.41588932118917765
epoch:  30


44it [00:08,  5.45it/s]


Train Loss:  0.35825667530298233


1it [00:00,  5.10it/s]

Test Loss:  0.43233390858298854
epoch:  31


44it [00:08,  5.33it/s]


Train Loss:  0.3458407914096659


1it [00:00,  5.22it/s]

Test Loss:  0.4176055914477298
epoch:  32


44it [00:08,  5.08it/s]


Train Loss:  0.3458079512823712


1it [00:00,  5.07it/s]

Test Loss:  0.4137911702457227
epoch:  33


44it [00:09,  4.44it/s]


Train Loss:  0.3431067710573023


0it [00:00, ?it/s]

Test Loss:  0.4098980458159196
epoch:  34


44it [00:20,  2.17it/s]


Train Loss:  0.36488682370294223


0it [00:00, ?it/s]

Test Loss:  0.4459237707288642
epoch:  35


44it [00:44,  1.01s/it]


Train Loss:  0.3682500631971793


0it [00:00, ?it/s]

Test Loss:  0.4256601412045328
epoch:  36


44it [01:11,  1.62s/it]


Train Loss:  0.34051186726851895


0it [00:00, ?it/s]

Test Loss:  0.4097025017989309
epoch:  37


44it [01:27,  1.99s/it]


Train Loss:  0.3382303850217299


0it [00:00, ?it/s]

Test Loss:  0.4121368292130922
epoch:  38


44it [01:35,  2.17s/it]


Train Loss:  0.34243404594334687


0it [00:00, ?it/s]

Test Loss:  0.4089476372066297
epoch:  39


44it [01:38,  2.23s/it]


Train Loss:  0.3465461893515153


0it [00:00, ?it/s]

Test Loss:  0.425997290172075
epoch:  40


44it [01:41,  2.31s/it]


Train Loss:  0.35072071105241776


0it [00:00, ?it/s]

Test Loss:  0.4142472853786067
epoch:  41


44it [01:45,  2.39s/it]


Train Loss:  0.39044087583368475


0it [00:00, ?it/s]

Test Loss:  0.4818680255036605
epoch:  42


44it [01:46,  2.41s/it]


Train Loss:  0.3328975842080333


0it [00:00, ?it/s]

Test Loss:  0.40751962128438446
epoch:  43


44it [01:49,  2.49s/it]


Train Loss:  0.3489474274895408


0it [00:00, ?it/s]

Test Loss:  0.41319263922540767
epoch:  44


44it [01:53,  2.58s/it]


Train Loss:  0.3621596490794962


0it [00:00, ?it/s]

Test Loss:  0.4228024451356185
epoch:  45


44it [01:54,  2.59s/it]


Train Loss:  0.35041038488799875


0it [00:00, ?it/s]

Test Loss:  0.4139113049758108
epoch:  46


44it [01:55,  2.63s/it]


Train Loss:  0.3304741189561107


0it [00:00, ?it/s]

Test Loss:  0.40441983938217163
epoch:  47


44it [01:56,  2.65s/it]


Train Loss:  0.40936286747455597


0it [00:00, ?it/s]

Test Loss:  0.4647023144521211
epoch:  48


44it [01:57,  2.68s/it]


Train Loss:  0.43032506040551444


0it [00:00, ?it/s]

Test Loss:  0.4833308489699113
epoch:  49


44it [01:57,  2.67s/it]


Train Loss:  0.3571048886938529
Test Loss:  0.4206127367521587


In [14]:
torch.save(net.state_dict(), "./LiNet.p")