In [31]:
import torch
import pandas as pd
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset

from src.config import CLEANED_PATH, DataType, UNCLEANED_PATH, RAW_FILE, INFERENCE_FILE


class MultiNewsDataset(Dataset):
    
    def __init__(self, data_type):
        if data_type == DataType.train.name:
            train_path = CLEANED_PATH.format(data_type=DataType.train.name) + INFERENCE_FILE        
            train_df = pd.read_csv(train_path, sep=',')
            train_df.drop(columns=['pair_id'], inplace=True)


            l_cols = ['entities', 'narrative', 'time', 'geography', 'overall']
            y = train_df[l_cols]
            train_df.drop(columns=l_cols, inplace=True)

            y = y.values
            x = train_df.values
        else:
            test_path = CLEANED_PATH.format(data_type=DataType.test.name) + INFERENCE_FILE
            test_df = pd.read_csv(test_path, sep=',', 
                                  usecols=['sentences_mean', 'sentences_min', 'sentences_max',
                                           'sentences_med', 'title', 'n1_title_n2_text',
                                           'n2_title_n1_text', 'n1_title_n1_text',
                                           'n2_title_n2_text', 'start_para', 'end_para',
                                           'ner', 'tf_idf', 'wmd_dist'])
            x = test_df.values
            y = 0
        
        self.x = torch.tensor(x, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)

    def __len__(self):
        return len(self.y)

    def __getitem__(self,idx):
        return self.x[idx], self.y[idx]


class Autoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(14, 22),
            nn.ReLU(),
            nn.Linear(22, 14),
            nn.ReLU()
        )
        
        self.bn = nn.Sequential(nn.Linear(14, 4), nn.ReLU())
        
        self.decoder = nn.Sequential(
            nn.Linear(4, 14),
            nn.ReLU(),
        )
    
    def forward(self, x):
        x = self.encoder(x)
        
        out = self.bn(x)
        
        recon = self.decoder(out)
        return out, recon


class TaskRegressor(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(4, 1), nn.ReLU())
        
    def forward(self, x):
        return self.mlp(x)


ae = Autoencoder()
geo = TaskRegressor()
ent = TaskRegressor()
tim = TaskRegressor()
nar = TaskRegressor()
ovr = TaskRegressor()

# MSE Loss 
loss = nn.MSELoss()
optimizer = optim.Adam((list(ae.parameters())+list(geo.parameters())+list(ent.parameters())+
                        list(tim.parameters())+list(nar.parameters())+list(ovr.parameters())),lr=0.01)
max_epoch = 300
lmbda = 0.01

myDs = MultiNewsDataset(DataType.train.name)
train_loader = DataLoader(myDs, batch_size=100, shuffle=False)

for ep in range(max_epoch):
    for data in train_loader:
        optimizer.zero_grad()
        x, y = data
        out, recon = ae(x)

        # Encoder-Decoder forward pass
        recon_loss = loss(x, recon)
        
#         'entities', 'narrative', 'time', 'geography', 'overall'
        gold_data_ent, gold_data_nar, gold_data_tim, gold_data_geo, gold_data_ovr = y[:][0], y[:][1], y[:][2], y[:][3], y[:][4]
        # Regressor forward-pass
        geo_loss = loss(geo(out), gold_data_geo)
        ent_loss = loss(ent(out), gold_data_ent)
        tim_loss = loss(tim(out), gold_data_tim)
        nar_loss = loss(nar(out), gold_data_nar)
        ovr_loss = loss(ovr(out), gold_data_ovr)

        # Accumulate Loss
#         main_loss = recon_loss + 0.5*(geo_loss+ent_loss+tim_loss+nar_loss+ovr_loss)
        main_loss = 0.5*recon_loss + 0.5*(ovr_loss)

        main_loss.backward()  # Compute gradients
        optimizer.step()  # Back-prop
    print(f'epoch [{ep + 1}/{max_epoch}], loss:{main_loss.item(): .4f}')

#sentences_mean,sentences_min,sentences_max,sentences_med,
#title,n1_title_n2_text,n2_title_n1_text,n1_title_n1_text,n2_title_n2_text,
#start_para,end_para,ner,tf_idf,wmd_dist,


#geography,entities,time,narrative,overall

epoch [1/300], loss: 0.5186
epoch [2/300], loss: 0.4796
epoch [3/300], loss: 0.4975
epoch [4/300], loss: 0.5245
epoch [5/300], loss: 0.5399
epoch [6/300], loss: 0.5597
epoch [7/300], loss: 0.5726
epoch [8/300], loss: 0.5820
epoch [9/300], loss: 0.5631
epoch [10/300], loss: 0.5794
epoch [11/300], loss: 0.5964
epoch [12/300], loss: 0.5925
epoch [13/300], loss: 0.5867
epoch [14/300], loss: 0.5787
epoch [15/300], loss: 0.5722
epoch [16/300], loss: 0.5642
epoch [17/300], loss: 0.5577
epoch [18/300], loss: 0.5678
epoch [19/300], loss: 0.5445
epoch [20/300], loss: 0.5413
epoch [21/300], loss: 0.5367
epoch [22/300], loss: 0.5318
epoch [23/300], loss: 0.5266
epoch [24/300], loss: 0.5214
epoch [25/300], loss: 0.5164
epoch [26/300], loss: 0.5119
epoch [27/300], loss: 0.5078
epoch [28/300], loss: 0.5041
epoch [29/300], loss: 0.5008
epoch [30/300], loss: 0.4979
epoch [31/300], loss: 0.4954
epoch [32/300], loss: 0.4932
epoch [33/300], loss: 0.4911
epoch [34/300], loss: 0.4891
epoch [35/300], loss: 0

epoch [280/300], loss: 0.4182
epoch [281/300], loss: 0.4177
epoch [282/300], loss: 0.4175
epoch [283/300], loss: 0.4177
epoch [284/300], loss: 0.4173
epoch [285/300], loss: 0.4174
epoch [286/300], loss: 0.4172
epoch [287/300], loss: 0.4171
epoch [288/300], loss: 0.4159
epoch [289/300], loss: 0.4177
epoch [290/300], loss: 0.4166
epoch [291/300], loss: 0.4185
epoch [292/300], loss: 0.4168
epoch [293/300], loss: 0.4180
epoch [294/300], loss: 0.4162
epoch [295/300], loss: 0.4151
epoch [296/300], loss: 0.4166
epoch [297/300], loss: 0.4357
epoch [298/300], loss: 0.4151
epoch [299/300], loss: 0.4159
epoch [300/300], loss: 0.4142


In [32]:
testDS = MultiNewsDataset(DataType.test.name)
test_loader = DataLoader(testDS, batch_size=100, shuffle=False)

for data in train_loader:
    torch.no_grad()
    x, _ = data
    out, _ = ae(x)
    pred = ovr(out)
    print(pred)

tensor([[2.7717],
        [2.4511],
        [2.4479],
        [2.4478],
        [2.4477],
        [2.4472],
        [2.6818],
        [2.5732],
        [2.4478],
        [2.4503],
        [2.7383],
        [2.4509],
        [2.4515],
        [2.6488],
        [2.4505],
        [2.4800],
        [2.5862],
        [2.4497],
        [2.4515],
        [2.4477],
        [2.5730],
        [2.4612],
        [2.5219],
        [2.4471],
        [2.6001],
        [2.4487],
        [2.5077],
        [2.6563],
        [2.4472],
        [2.4507],
        [2.4489],
        [2.4879],
        [2.4493],
        [2.6574],
        [2.5560],
        [2.5662],
        [2.4507],
        [2.4476],
        [2.4519],
        [2.4479],
        [2.4472],
        [2.5142],
        [2.4477],
        [2.4471],
        [2.4644],
        [2.4477],
        [2.4467],
        [2.4468],
        [2.4492],
        [2.4515],
        [2.4478],
        [2.5024],
        [2.4479],
        [2.4467],
        [2.4514],
        [2