### Salary prediction, episode II: make it actually work (4 points)

Your main task is to use some of the tricks you've learned on the network and analyze if you can improve __validation MAE__. Try __at least 3 options__ from the list below for a passing grade. Write a short report about what you have tried. More ideas = more bonus points. 

__Please be serious:__ " plot learning curves in MAE/epoch, compare models based on optimal performance, test one change at a time. You know the drill :)

You can use either __pytorch__ or __tensorflow__ or any other framework (e.g. pure __keras__). Feel free to adapt the seminar code for your needs. For tensorflow version, consider `seminar_tf2.ipynb` as a starting point.


In [1]:
# Reuse of data handling routines

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import nltk
from collections import Counter
from sklearn.feature_extraction import DictVectorizer
from sklearn.model_selection import train_test_split

data = pd.read_csv("week02_classification/Train_rev1.csv", index_col=None)
data['Log1pSalary'] = np.log1p(data['SalaryNormalized']).astype('float32')
text_columns = ["Title", "FullDescription"]
categorical_columns = ["Category", "Company", "LocationNormalized", "ContractType", "ContractTime"]
TARGET_COLUMN = "Log1pSalary"
data[categorical_columns] = data[categorical_columns].fillna('NaN') # cast missing values to string "NaN"
tokenizer = nltk.tokenize.WordPunctTokenizer()

def space_div(text):
    return ' '.join(tokenizer.tokenize(str(text).lower()))

data["Title"] = data["Title"].apply(space_div)
data["FullDescription"] = data["FullDescription"].apply(space_div)
token_counts = Counter()

for i in range(len(data)):
    token_counts.update(data["Title"][i].split(' '))
    token_counts.update(data["FullDescription"][i].split(' '))

tokens = [key for key, value in token_counts.items() if value >= 10]
UNK, PAD = "UNK", "PAD"
tokens = [UNK, PAD] + tokens

token_to_id = {tokens[i]: i for i in range(len(tokens))}
UNK_IX, PAD_IX = map(token_to_id.get, [UNK, PAD])

def as_matrix(sequences, max_len=None):
    """ Convert a list of tokens into a matrix with padding """
    if isinstance(sequences[0], str):
        sequences = list(map(str.split, sequences))
        
    max_len = min(max(map(len, sequences)), max_len or float('inf'))
    
    matrix = np.full((len(sequences), max_len), np.int32(PAD_IX))
    for i, seq in enumerate(sequences):
        row_ix = [token_to_id.get(word, UNK_IX) for word in seq[:max_len]]
        matrix[i, :len(row_ix)] = row_ix
    
    return matrix

top_companies, top_counts = zip(*Counter(data['Company']).most_common(1000))
recognized_companies = set(top_companies)
data["Company"] = data["Company"].apply(lambda comp: comp if comp in recognized_companies else "Other")

categorical_vectorizer = DictVectorizer(dtype=np.float32, sparse=False)
categorical_vectorizer.fit(data[categorical_columns].apply(dict, axis=1))

data_train, data_val = train_test_split(data, test_size=0.2, random_state=42)
data_train.index = range(len(data_train))
data_val.index = range(len(data_val))

print("Train size = ", len(data_train))
print("Validation size = ", len(data_val))


Train size =  195814
Validation size =  48954


In [2]:
import gensim.downloader as api
embeddings = api.load('glove-wiki-gigaword-300');

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F


device = 'cuda' if torch.cuda.is_available() else 'cpu'


def to_tensors(batch, device):
    batch_tensors = dict()
    for key, arr in batch.items():
        if key in ["FullDescription", "Title"]:
            batch_tensors[key] = torch.tensor(arr, device=device, dtype=torch.int64)
        else:
            batch_tensors[key] = torch.tensor(arr, device=device)
    return batch_tensors


def make_batch(data, max_len=None, word_dropout=0, device=device):
    """
    Creates a keras-friendly dict from the batch data.
    :param word_dropout: replaces token index with UNK_IX with this probability
    :returns: a dict with {'title' : int64[batch, title_max_len]
    """
    batch = {}
    batch["Title"] = as_matrix(data["Title"].values, max_len)
    batch["FullDescription"] = as_matrix(data["FullDescription"].values, max_len)
    batch['Categorical'] = categorical_vectorizer.transform(data[categorical_columns].apply(dict, axis=1))
    
    if word_dropout != 0:
        batch["FullDescription"] = apply_word_dropout(batch["FullDescription"], 1. - word_dropout)
    
    if TARGET_COLUMN in data.columns:
        batch[TARGET_COLUMN] = data[TARGET_COLUMN].values
    
    return to_tensors(batch, device)

def apply_word_dropout(matrix, keep_prop, replace_with=UNK_IX, pad_ix=PAD_IX,):
    dropout_mask = np.random.choice(2, np.shape(matrix), p=[keep_prop, 1 - keep_prop])
    dropout_mask &= matrix != pad_ix
    return np.choose(dropout_mask, [matrix, np.full_like(matrix, replace_with)])

#### 1. (A & C) Applying CNN tricks & using pretrained word2vec model

In [4]:
example_batch = make_batch(data_train[:3], max_len=10)
example_batch

{'Title': tensor([[  320,    89,  1657,     1,     1,     1,     1],
         [ 5130,   130,    25,   173,    14,   562, 21820],
         [ 2246,    42,  1433,   109,  9312,  9313,   116]], device='cuda:0'),
 'FullDescription': tensor([[  320,    89,  1657,  2890,    44,   320,    89,  1657,    12,  2890],
         [ 5130,   130,    25,   173,    14,   562, 21820,  1334,   129,     8],
         [   49,    66,   444,    11,    12,    37,   576,    44,    42,  1433]],
        device='cuda:0'),
 'Categorical': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0'),
 'Log1pSalary': tensor([ 9.7115, 10.4631, 10.7144], device='cuda:0')}

In [6]:
# validating dimensions

emb = nn.Embedding(num_embeddings=len(tokens), embedding_dim=128).to('cuda')
x = emb(example_batch['Title'])
print('Initial title input size:', x.shape)
x = x.transpose(2, 1)
text_conv_title_1 = nn.Conv1d(in_channels=128, out_channels=16, kernel_size=3).to('cuda')
x_1 = text_conv_title_1(x)
print('After title conv1:', x_1.shape)
x_1 = torch.amax(x_1, dim=2)
print('After pooling:', x_1.shape)

text_conv_title_2 = nn.Conv1d(in_channels=128, out_channels=16, kernel_size=2).to('cuda')

x_2 = text_conv_title_2(x)
print('After title conv2:', x_2.shape)
x_2 = torch.amax(x_2, dim=2)
print('After pooling:', x_2.shape)


text_conv_descr_1 = nn.Conv1d(in_channels=128, out_channels=32, kernel_size=3, padding=1).to('cuda')
text_conv_descr_1_2 = nn.Conv1d(in_channels=32, out_channels=16, kernel_size=3).to('cuda')
y = emb(example_batch['FullDescription'])
print('Initial descr input size:', y.shape)
y = y.transpose(2, 1)
y_1 = F.leaky_relu(text_conv_descr_1(y))
print('After descr conv1:', y_1.shape)
y_1 = text_conv_descr_1_2(y_1)
print('After descr conv1_2:', y_1.shape)
y_1 = torch.amax(y_1, dim=2)
print('After pooling', y_1.shape)

text_conv_descr_2 = nn.Conv1d(in_channels=128, out_channels=32, kernel_size=2, padding=1).to('cuda')
text_conv_descr_2_2 = nn.Conv1d(in_channels=32, out_channels=16, kernel_size=2).to('cuda')

text_conv_descr_3 = nn.Conv1d(in_channels=128, out_channels=32, kernel_size=4, padding=2).to('cuda')
text_conv_descr_3_2 = nn.Conv1d(in_channels=32, out_channels=16, kernel_size=4).to('cuda')


y_2 = F.leaky_relu(text_conv_descr_2(y))
print('After descr conv2:', y_2.shape)
y_2 = text_conv_descr_2_2(y_2)
print('After descr conv2_2:' , y_2.shape)
y_2 = torch.amax(y_2, dim=2)
print('After pooling', y_2.shape)

y_3 = F.leaky_relu(text_conv_descr_3(y))
print('After descr conv3:', y_3.shape)
y_3 = text_conv_descr_3_2(y_3)
print('After descr conv3_2:', y_3.shape)
y_3 = torch.amax(y_3, dim=2)
print('After pooling', y_3.shape)


fc_categorical = nn.Linear(len(categorical_vectorizer.vocabulary_), 16).to('cuda')
z = fc_categorical(example_batch['Categorical'])
print(z.shape)

u = torch.cat([x_1, y_1, x_2, y_2, y_3, z], dim=1)
print(u.shape)

final_fc_input_dim = 16 * 6
fc = nn.Linear(final_fc_input_dim, 1).to('cuda')

u = fc(u)
print(u.shape)

u = u.view(-1)
print(u.shape)


Initial title input size: torch.Size([3, 7, 128])
After title conv1: torch.Size([3, 16, 5])
After pooling: torch.Size([3, 16])
After title conv2: torch.Size([3, 16, 6])
After pooling: torch.Size([3, 16])
Initial descr input size: torch.Size([3, 10, 128])
After descr conv1: torch.Size([3, 32, 10])
After descr conv1_2: torch.Size([3, 16, 8])
After pooling torch.Size([3, 16])
After descr conv2: torch.Size([3, 32, 11])
After descr conv2_2: torch.Size([3, 16, 10])
After pooling torch.Size([3, 16])
After descr conv3: torch.Size([3, 32, 11])
After descr conv3_2: torch.Size([3, 16, 8])
After pooling torch.Size([3, 16])
torch.Size([3, 16])
torch.Size([3, 96])
torch.Size([3, 1])
torch.Size([3])


Architectural Modifications:
1. Increased embedding dimension to enhance feature representation.
2. Introduced additional convolutional layers with varying kernel sizes for both the title and full description fields.
3. Added a novel convolution-over-convolution layer to the convolutional stack for the full description field.
4. Added batch normalization layers after each convolution layer.

In [5]:
class SalaryPredictor(nn.Module):
    def __init__(self, n_tokens=len(tokens), n_cat_features=len(categorical_vectorizer.vocabulary_), embedding_dim=300):
        super().__init__()
        # text convolution kernels
        self.kernel_size_1 = 3
        self.kernel_size_2 = 2
        self.kernel_size_3 = 4
        self.embedding_dim = embedding_dim
        self.conv_out_dim_title = 8
        self.conv_out_dim_descr_1 = 32
        self.conv_out_dim_descr_2 = 8
        self.emb = nn.Embedding(num_embeddings=n_tokens, embedding_dim=self.embedding_dim)
        self.text_conv_title_1 = nn.Conv1d(in_channels=self.embedding_dim, out_channels=self.conv_out_dim_title, kernel_size=self.kernel_size_1)
        self.text_conv_title_2 = nn.Conv1d(in_channels=self.embedding_dim, out_channels=self.conv_out_dim_title, kernel_size=self.kernel_size_2)

        self.text_conv_descr_1_1 = nn.Conv1d(in_channels=self.embedding_dim, out_channels=self.conv_out_dim_descr_1, 
                                            kernel_size=self.kernel_size_1, padding=self.kernel_size_1 // 2)
        self.text_conv_descr_2_1 = nn.Conv1d(in_channels=self.embedding_dim, out_channels=self.conv_out_dim_descr_1, 
                                            kernel_size=self.kernel_size_2, padding=self.kernel_size_2 // 2)
        self.text_conv_descr_3_1 = nn.Conv1d(in_channels=self.embedding_dim, out_channels=self.conv_out_dim_descr_1,
                                            kernel_size=self.kernel_size_3, padding=self.kernel_size_3 // 2)

        self.text_conv_descr_1_2 = nn.Conv1d(in_channels=self.conv_out_dim_descr_1, out_channels=self.conv_out_dim_descr_2, kernel_size=self.kernel_size_1)
        self.text_conv_descr_2_2 = nn.Conv1d(in_channels=self.conv_out_dim_descr_1, out_channels=self.conv_out_dim_descr_2, kernel_size=self.kernel_size_2)
        self.text_conv_descr_3_2 = nn.Conv1d(in_channels=self.conv_out_dim_descr_1, out_channels=self.conv_out_dim_descr_2, kernel_size=self.kernel_size_3)

        self.bn_title_1 = nn.BatchNorm1d(self.conv_out_dim_title)
        self.bn_title_2 = nn.BatchNorm1d(self.conv_out_dim_title)

        self.bn_descr_1_1 = nn.BatchNorm1d(self.conv_out_dim_descr_1)
        self.bn_descr_2_1 = nn.BatchNorm1d(self.conv_out_dim_descr_1)
        self.bn_descr_3_1 = nn.BatchNorm1d(self.conv_out_dim_descr_1)

        self.bn_descr_1_2 = nn.BatchNorm1d(self.conv_out_dim_descr_2)
        self.bn_descr_2_2 = nn.BatchNorm1d(self.conv_out_dim_descr_2)
        self.bn_descr_3_2 = nn.BatchNorm1d(self.conv_out_dim_descr_2)
        
        # categorical kernels
        self.categorical_out_dim = 8
        self.fc_categorical = nn.Linear(n_cat_features, self.categorical_out_dim)

        # final linear layer
        self.final_fc_input_dim = 2 * self.conv_out_dim_title \
                + 3 * self.conv_out_dim_descr_2 \
                + self.categorical_out_dim
        self.fc = nn.Linear(self.final_fc_input_dim, 1)
        
    def forward(self, batch):
        x = self.emb(batch['Title'])
        x = x.transpose(2, 1)
        y = self.emb(batch['FullDescription'])
        y = y.transpose(2, 1)

        x_1 = torch.amax(self.bn_title_1(self.text_conv_title_1(x)), dim=2)
        x_2 = torch.amax(self.bn_title_2(self.text_conv_title_2(x)), dim=2)

        y_1 = F.leaky_relu(self.bn_descr_1_1(self.text_conv_descr_1_1(y)))
        y_2 = F.leaky_relu(self.bn_descr_2_1(self.text_conv_descr_2_1(y)))
        y_3 = F.leaky_relu(self.bn_descr_3_1(self.text_conv_descr_3_1(y)))

        y_1 = torch.amax(self.bn_descr_1_2(self.text_conv_descr_1_2(y_1)), dim=2)
        y_2 = torch.amax(self.bn_descr_2_2(self.text_conv_descr_2_2(y_2)), dim=2)
        y_3 = torch.amax(self.bn_descr_3_2(self.text_conv_descr_3_2(y_3)), dim=2)

        z = F.leaky_relu(self.fc_categorical(batch['Categorical']))
        u = torch.cat([x_1, x_2, y_1, y_2, y_3, z], dim=1)
        u = self.fc(u).view(-1)
        return u

In [5]:
def iterate_minibatches(data, batch_size=256, shuffle=True, cycle=False, device=device, **kwargs):
    """ iterates minibatches of data in random order """
    while True:
        indices = np.arange(len(data))
        if shuffle:
            indices = np.random.permutation(indices)

        for start in range(0, len(indices), batch_size):
            batch = make_batch(data.iloc[indices[start : start + batch_size]], device=device, **kwargs)
            yield batch
        
        if not cycle: break

In [6]:
from tqdm.auto import tqdm

BATCH_SIZE = 32
EPOCHS = 25

In [7]:
def print_metrics(model, data, batch_size=BATCH_SIZE, name="", device=torch.device('cpu'), **kw):
    squared_error = abs_error = num_samples = abs_error_in_dollars = 0.0
    model.eval()
    with torch.no_grad():
        for batch in iterate_minibatches(data, batch_size=batch_size, shuffle=False, device=device, **kw):
            batch_pred = model(batch)
            squared_error += torch.sum(torch.square(batch_pred - batch[TARGET_COLUMN]))
            abs_error += torch.sum(torch.abs(batch_pred - batch[TARGET_COLUMN]))
            abs_error_in_dollars += torch.sum(torch.abs(torch.exp(batch_pred) - torch.exp(batch[TARGET_COLUMN])))
            num_samples += len(batch_pred)
    mse = squared_error.detach().cpu().numpy() / num_samples
    mae = abs_error.detach().cpu().numpy() / num_samples
    mae_in_dollars = abs_error_in_dollars.detach().cpu().numpy() / num_samples
    print("%s results:" % (name or ""))
    print("Mean square error: %.5f" % mse)
    print("Mean absolute error: %.5f" % mae)
    print("Mean absolute error in dollars: %.5f" % mae_in_dollars)
    return mse, mae


Let us set pretrained weights for the embedding layer.

In [12]:
model = SalaryPredictor()
embedding_initialization = model.emb.weight.detach().numpy()
pretrained_weights_embedding = torch.Tensor(np.array([
    embeddings[tokens[i]] if tokens[i] in embeddings else embedding_initialization[i]
    for i in range(len(tokens))
]))
print(np.array([0 if tokens[i] in embeddings else 1 for i in range(len(tokens))]).sum(), 'words out of', len(tokens), 'are missing in the word2vec model.')
model.emb.weight.data = pretrained_weights_embedding
model = model.to(device)

7935 words out of 34158 are missing in the word2vec model.


In [12]:
criterion = nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

for epoch in range(EPOCHS):
    print(f"epoch: {epoch}")
    model.train()
    for i, batch in tqdm(enumerate(
            iterate_minibatches(data_train, batch_size=BATCH_SIZE, device=device, word_dropout=0.4)),
            total=len(data_train) // BATCH_SIZE
        ):
        pred = model(batch)
        loss = criterion(pred, batch[TARGET_COLUMN])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    print_metrics(model, data_val, device=device)

epoch: 0


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 12.46571
Mean absolute error: 3.28146
Mean absolute error in dollars: 31791.58230
epoch: 1


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 11.20480
Mean absolute error: 3.07166
Mean absolute error in dollars: 31160.59779
epoch: 2


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 9.91846
Mean absolute error: 2.87556
Mean absolute error in dollars: 30675.04809
epoch: 3


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 9.05937
Mean absolute error: 2.73119
Mean absolute error in dollars: 30184.39188
epoch: 4


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 9.56734
Mean absolute error: 2.81528
Mean absolute error in dollars: 30464.64583
epoch: 5


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 7.61628
Mean absolute error: 2.46056
Mean absolute error in dollars: 29060.13645
epoch: 6


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 8.69138
Mean absolute error: 2.66482
Mean absolute error in dollars: 29925.46864
epoch: 7


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 8.12885
Mean absolute error: 2.56366
Mean absolute error in dollars: 29570.73759
epoch: 8


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 7.52018
Mean absolute error: 2.44183
Mean absolute error in dollars: 28940.68391
epoch: 9


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 6.33237
Mean absolute error: 2.23062
Mean absolute error in dollars: 28160.83147
epoch: 10


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 6.01529
Mean absolute error: 2.16508
Mean absolute error in dollars: 27819.85897
epoch: 11


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 6.53499
Mean absolute error: 2.27276
Mean absolute error in dollars: 28397.06696
epoch: 12


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 5.71921
Mean absolute error: 2.08218
Mean absolute error in dollars: 27145.72407
epoch: 13


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 5.71206
Mean absolute error: 2.09650
Mean absolute error in dollars: 27399.12244
epoch: 14


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 5.53881
Mean absolute error: 2.04402
Mean absolute error in dollars: 26932.39596
epoch: 15


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 6.07152
Mean absolute error: 2.18017
Mean absolute error in dollars: 27951.14565
epoch: 16


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 5.49952
Mean absolute error: 2.04798
Mean absolute error in dollars: 27120.16799
epoch: 17


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 5.05250
Mean absolute error: 1.94784
Mean absolute error in dollars: 26516.49663
epoch: 18


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 4.84193
Mean absolute error: 1.89608
Mean absolute error in dollars: 26150.87012
epoch: 19


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 4.79059
Mean absolute error: 1.89116
Mean absolute error in dollars: 26225.36553
epoch: 20


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 4.97415
Mean absolute error: 1.93749
Mean absolute error in dollars: 26518.76096
epoch: 21


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 4.65344
Mean absolute error: 1.85494
Mean absolute error in dollars: 25902.24619
epoch: 22


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 4.29639
Mean absolute error: 1.76472
Mean absolute error in dollars: 25226.66536
epoch: 23


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 4.61974
Mean absolute error: 1.85464
Mean absolute error in dollars: 25999.43980
epoch: 24


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 3.99493
Mean absolute error: 1.68822
Mean absolute error in dollars: 24618.05744


#### 2. Playing with pooling layer

Key changes:
1. Maxpooling is now combined with a custom 5-head attention layer.
2. Dropout layer is applied on concatenated pooling vectors.

In [8]:
text_conv_title_1 = nn.Conv1d(in_channels=300, out_channels=32, kernel_size=3).to('cuda')
bn_title_1 = nn.BatchNorm1d(32).to('cuda')
attn_title_1 = nn.Linear(32, 5).to('cuda')
emb = nn.Embedding(num_embeddings=len(tokens), embedding_dim=300).to('cuda')
batch_size_ = example_batch['Title'].shape[0]

x = emb(example_batch["Title"])
x = x.transpose(2, 1)
x = text_conv_title_1(x)
x = bn_title_1(x)
print(x.shape)
attn_x = attn_title_1(x.transpose(2, 1))
print(attn_x.shape)
attn_x = nn.Softmax(dim=1)(attn_x)
print(torch.sum(attn_x, dim=1))

# attn_x = nn.Softmax(dim=1)(attn_title_1(x.transpose(2, 1)))
print(attn_x)
x_attnpool = torch.matmul(x, attn_x).view(batch_size_, -1)
print(x_attnpool.shape)
x_maxpool = torch.amax(x, dim=2)
print(x_maxpool.shape)

torch.Size([3, 32, 5])
torch.Size([3, 5, 5])
tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000]], device='cuda:0',
       grad_fn=<SumBackward1>)
tensor([[[0.1407, 0.1360, 0.2362, 0.1227, 0.3049],
         [0.3332, 0.1631, 0.2909, 0.1743, 0.1889],
         [0.1536, 0.2668, 0.1345, 0.2918, 0.1360],
         [0.1862, 0.2171, 0.1692, 0.2056, 0.1851],
         [0.1862, 0.2171, 0.1692, 0.2056, 0.1851]],

        [[0.4090, 0.1719, 0.1776, 0.1842, 0.1571],
         [0.0830, 0.1629, 0.2221, 0.1723, 0.2630],
         [0.0816, 0.2105, 0.0640, 0.2373, 0.1845],
         [0.3833, 0.2580, 0.0957, 0.1591, 0.0427],
         [0.0432, 0.1967, 0.4405, 0.2472, 0.3527]],

        [[0.1341, 0.2415, 0.1290, 0.4327, 0.1425],
         [0.2168, 0.0720, 0.2154, 0.1778, 0.2960],
         [0.2365, 0.2495, 0.2605, 0.1499, 0.1954],
         [0.0875, 0.2601, 0.2849, 0.1227, 0.1815],
         [0.3251, 0.1770, 0.1102, 0.11

In [8]:
class SalaryPredictor(nn.Module):
    def __init__(self, n_tokens=len(tokens), n_cat_features=len(categorical_vectorizer.vocabulary_), embedding_dim=300):
        super().__init__()
        # text convolution kernels
        self.kernel_size_1 = 3
        self.kernel_size_2 = 2
        self.kernel_size_3 = 4
        self.embedding_dim = embedding_dim
        self.conv_out_dim_title = 8
        self.conv_out_dim_descr_1 = 32
        self.conv_out_dim_descr_2 = 8
        self.emb = nn.Embedding(num_embeddings=n_tokens, embedding_dim=self.embedding_dim)
        self.text_conv_title_1 = nn.Conv1d(in_channels=self.embedding_dim, out_channels=self.conv_out_dim_title, kernel_size=self.kernel_size_1)
        self.text_conv_title_2 = nn.Conv1d(in_channels=self.embedding_dim, out_channels=self.conv_out_dim_title, kernel_size=self.kernel_size_2)

        self.text_conv_descr_1_1 = nn.Conv1d(in_channels=self.embedding_dim, out_channels=self.conv_out_dim_descr_1, 
                                            kernel_size=self.kernel_size_1, padding=self.kernel_size_1 // 2)
        self.text_conv_descr_2_1 = nn.Conv1d(in_channels=self.embedding_dim, out_channels=self.conv_out_dim_descr_1, 
                                            kernel_size=self.kernel_size_2, padding=self.kernel_size_2 // 2)
        self.text_conv_descr_3_1 = nn.Conv1d(in_channels=self.embedding_dim, out_channels=self.conv_out_dim_descr_1,
                                            kernel_size=self.kernel_size_3, padding=self.kernel_size_3 // 2)

        self.text_conv_descr_1_2 = nn.Conv1d(in_channels=self.conv_out_dim_descr_1, out_channels=self.conv_out_dim_descr_2, kernel_size=self.kernel_size_1)
        self.text_conv_descr_2_2 = nn.Conv1d(in_channels=self.conv_out_dim_descr_1, out_channels=self.conv_out_dim_descr_2, kernel_size=self.kernel_size_2)
        self.text_conv_descr_3_2 = nn.Conv1d(in_channels=self.conv_out_dim_descr_1, out_channels=self.conv_out_dim_descr_2, kernel_size=self.kernel_size_3)

        self.bn_title_1 = nn.BatchNorm1d(self.conv_out_dim_title)
        self.bn_title_2 = nn.BatchNorm1d(self.conv_out_dim_title)

        self.bn_descr_1_1 = nn.BatchNorm1d(self.conv_out_dim_descr_1)
        self.bn_descr_2_1 = nn.BatchNorm1d(self.conv_out_dim_descr_1)
        self.bn_descr_3_1 = nn.BatchNorm1d(self.conv_out_dim_descr_1)

        self.bn_descr_1_2 = nn.BatchNorm1d(self.conv_out_dim_descr_2)
        self.bn_descr_2_2 = nn.BatchNorm1d(self.conv_out_dim_descr_2)
        self.bn_descr_3_2 = nn.BatchNorm1d(self.conv_out_dim_descr_2)

        self.num_attn_heads = 5
        
        self.attn_title_1 = nn.Linear(self.conv_out_dim_title, self.num_attn_heads)
        self.attn_title_2 = nn.Linear(self.conv_out_dim_title, self.num_attn_heads)

        self.attn_descr_1 = nn.Linear(self.conv_out_dim_descr_2, self.num_attn_heads)
        self.attn_descr_2 = nn.Linear(self.conv_out_dim_descr_2, self.num_attn_heads)
        self.attn_descr_3 = nn.Linear(self.conv_out_dim_descr_2, self.num_attn_heads)
        
        # categorical kernels
        self.categorical_out_dim = 8
        self.fc_categorical = nn.Linear(n_cat_features, self.categorical_out_dim)

        # final linear layer
        self.final_fc_input_dim = 2 * (self.num_attn_heads + 1) * self.conv_out_dim_title \
                + 3 * (self.num_attn_heads + 1) * self.conv_out_dim_descr_2 \
                + self.categorical_out_dim
        self.fc = nn.Linear(self.final_fc_input_dim, 1)

        # additional considerations
        self.dropout = nn.Dropout(p=0.3)
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, batch):
        batch_size = batch['Title'].shape[0]
        x = self.emb(batch['Title'])
        x = x.transpose(2, 1)
        y = self.emb(batch['FullDescription'])
        y = y.transpose(2, 1)

        x_1_pre_pooling = self.bn_title_1(self.text_conv_title_1(x)) 
        x_2_pre_pooling = self.bn_title_2(self.text_conv_title_2(x))

        x_1_maxpool = torch.amax(x_1_pre_pooling, dim=2)
        title_1_attn = self.softmax(self.attn_title_1(x_1_pre_pooling.transpose(2, 1)))
        x_1_attnpool = torch.matmul(x_1_pre_pooling, title_1_attn).view(batch_size, -1)

        x_2_maxpool = torch.amax(x_2_pre_pooling, dim=2)
        title_2_attn = self.softmax(self.attn_title_2(x_2_pre_pooling.transpose(2, 1)))
        x_2_attnpool = torch.matmul(x_2_pre_pooling, title_2_attn).view(batch_size, -1)

        y_1 = F.leaky_relu(self.bn_descr_1_1(self.text_conv_descr_1_1(y)))
        y_2 = F.leaky_relu(self.bn_descr_2_1(self.text_conv_descr_2_1(y)))
        y_3 = F.leaky_relu(self.bn_descr_3_1(self.text_conv_descr_3_1(y)))

        y_1_pre_pooling = self.bn_descr_1_2(self.text_conv_descr_1_2(y_1))
        y_2_pre_pooling = self.bn_descr_2_2(self.text_conv_descr_2_2(y_2))
        y_3_pre_pooling = self.bn_descr_3_2(self.text_conv_descr_3_2(y_3))
        y_1_maxpool = torch.amax(y_1_pre_pooling, dim=2)
        descr_1_attn = self.softmax(self.attn_descr_1(y_1_pre_pooling.transpose(2, 1)))
        y_1_attnpool = torch.matmul(y_1_pre_pooling, descr_1_attn).view(batch_size, -1)
        y_2_maxpool = torch.amax(y_2_pre_pooling, dim=2)
        descr_2_attn = self.softmax(self.attn_descr_2(y_2_pre_pooling.transpose(2, 1)))
        y_2_attnpool = torch.matmul(y_2_pre_pooling, descr_2_attn).view(batch_size, -1)
        y_3_maxpool = torch.amax(y_3_pre_pooling, dim=2)
        descr_3_attn = self.softmax(self.attn_descr_3(y_3_pre_pooling.transpose(2, 1)))
        y_3_attnpool = torch.matmul(y_3_pre_pooling, descr_3_attn).view(batch_size, -1)

        z = F.leaky_relu(self.fc_categorical(batch['Categorical']))
        # print(x_1_maxpool.shape)
        # print(x_2_maxpool.shape)
        # print(y_1_maxpool.shape)
        # print(y_2_maxpool.shape)
        # print(y_3_maxpool.shape)
        # print(x_1_attnpool.shape)
        # print(x_2_attnpool.shape)
        # print(y_1_attnpool.shape)
        # print(y_2_attnpool.shape)
        # print(y_3_attnpool.shape)
        # print(z.shape)
        u = torch.cat([x_1_maxpool, x_2_maxpool, y_1_maxpool, y_2_maxpool, y_3_maxpool,
                        x_1_attnpool, x_2_attnpool, y_1_attnpool, y_2_attnpool, y_3_attnpool, z], dim=1)
        u = self.dropout(u)
        u = self.fc(u).view(-1)
        return u

In [10]:
model = SalaryPredictor()
embedding_initialization = model.emb.weight.detach().numpy()
pretrained_weights_embedding = torch.Tensor(np.array([
    embeddings[tokens[i]] if tokens[i] in embeddings else embedding_initialization[i]
    for i in range(len(tokens))
]))
missing_embeddings = [tokens[i] for i in range(len(tokens)) if tokens[i] not in embeddings]
print(len(missing_embeddings), 'words out of', len(tokens), 'are missing in the word2vec model.')
model.emb.weight.data = pretrained_weights_embedding
model = model.to(device)

7935 words out of 34158 are missing in the word2vec model.


In [25]:
missing_embeddings[200:220]

['preemployment',
 'svq',
 'upport',
 'scswis',
 'jobholder',
 'postregistration',
 'wellregarded',
 'stourportonsevern',
 '(****)',
 'nvq4',
 '95pm',
 'purposebuilt',
 'selfconfidence',
 'compliances',
 'pova',
 'closeknit',
 'pmld',
 'docare',
 '’.',
 'disciplinaries']

In [30]:
missing_embeddings[-200:-100]

['bonusabout',
 'isotrak',
 'fastcgi',
 'coodrinator',
 'dcucd',
 'dcda',
 'mwrc',
 'campaigncbmsuk',
 'grouptrader',
 'samworthchurchacademy',
 'bioanalyst',
 '9001en',
 '****…',
 'deljaaappointments',
 'dipfa',
 'cmrt',
 'electonics',
 'financiallymotivated',
 'leathergoods',
 'cloudfoundry',
 'voypic',
 'adminp',
 'autotype',
 'bakewellprojectresource',
 'eccuk',
 'cwdm',
 'generalmanagerdesignate_job',
 'hollymere',
 'rexs',
 'seanlejgroup',
 'axisweb',
 'lovc',
 'yearsmarket',
 'ynysybwl',
 'glyncoch',
 'personfor',
 'perioperatively',
 'rabaiotti',
 '>>>',
 'developmentsales',
 'janecarerecruitmentuk',
 'carerecruitmentuk',
 'sgarande4socialwork',
 '◊',
 'yoen',
 'knowldege',
 'bh1',
 'mediahawk',
 'fernandeshays',
 'dunhumby',
 'bestequipped',
 'adops',
 'cila',
 'percepta',
 'ruolo',
 'phenome',
 'chromatographymass',
 'pcse',
 'choiceconsultants',
 'tarnjeet',
 'pottermore',
 'jobstrspersonnel',
 'highcapital',
 'greeneking',
 'oteyour',
 'aqmen',
 'grateley',
 'striling',
 'l

In [26]:
embeddings["pre-employment"]

array([ 2.0319e-01, -1.7955e-01,  1.3091e-01,  4.0578e-02, -5.8469e-02,
       -1.3209e-01,  1.8004e-01, -5.0834e-02, -8.8997e-02,  6.2717e-01,
       -2.9239e-01,  1.1365e-01,  1.1766e-01, -1.2064e-01, -8.1639e-02,
        3.8736e-03,  3.2311e-01, -2.8180e-01,  2.8452e-01,  1.7923e-01,
       -6.2219e-02, -5.6736e-01, -7.9436e-01, -1.2483e-01,  2.4960e-01,
       -6.3184e-01, -1.5098e-01, -7.6473e-02,  1.6859e-01,  9.4797e-02,
       -3.3555e-01, -3.6312e-01,  5.9160e-01, -1.9709e-01,  7.0158e-01,
        1.3215e-01, -2.5647e-01,  9.3904e-02,  4.7473e-01, -3.3133e-01,
       -5.0672e-02,  4.6579e-02,  5.4687e-01, -1.1571e+00,  1.5673e-01,
       -4.9834e-01, -1.6967e-01, -1.8722e-01,  6.1965e-01, -5.9124e-01,
       -5.6397e-01, -2.5553e-01,  6.0134e-01,  1.8448e-01, -5.2633e-01,
       -5.0148e-01,  5.9354e-01, -3.1188e-01, -7.8267e-01,  9.1548e-02,
       -2.5027e-01,  1.5435e-01, -7.8129e-01, -1.5020e-01, -4.2698e-02,
        6.6598e-01, -1.1652e-01, -1.0867e-01,  1.6089e-01,  3.81

Some embeddings could be obtained if typos and misprints are corrected in the original text.

In [9]:
from tqdm.auto import tqdm

BATCH_SIZE = 32
EPOCHS = 50

A weight decay is added to the SGD optimizer to improve generalization.

In [10]:
lrs = [1e-6, 1e-5, 1e-4]
weight_decays = [0, 1e-4, 1e-2]
momentums = [0, 0.05, 0.1, 0.5, 0.9]

In [11]:
initial_weights_path = 'week02_classification/initial_weight_path.pt'

In [16]:
state_dict_cpu = {k: v.cpu() for k, v in model.state_dict().items()}
torch.save(state_dict_cpu, initial_weights_path)

In [18]:
import itertools

criterion = nn.MSELoss(reduction='sum')

def train_and_evaluate(lr, weight_decay, momentum):
    print(f'lr = {lr}, weight_decay = {weight_decay}, momentum = {momentum}')
    model = SalaryPredictor()
    model.load_state_dict(torch.load(initial_weights_path))
    model = model.to('cuda')
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay, momentum=momentum)
    for epoch in range(6):
        print(f"epoch: {epoch}")
        model.train()
        for i, batch in tqdm(enumerate(
                iterate_minibatches(data_train, batch_size=BATCH_SIZE, device=device, word_dropout=0.4)),
                total=len(data_train) // BATCH_SIZE):
            pred = model(batch)
            loss = criterion(pred, batch[TARGET_COLUMN])
            optimizer.zero_grad()
            loss.backward()
            # clipping the gradient to avoid nans
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
    return get_mse(model, data_val, device=device)

results = []
for lr, weight_decay, momentum in itertools.product(lrs, weight_decays, momentums):
    mse = train_and_evaluate(lr, weight_decay, momentum)
    print(mse)
    results.append((lr, weight_decay, momentum, mse))

best_hyperparams = min(results, key=lambda x: x[3])  # Find the combination with the lowest MSE

print("Grid search results:")
for lr, weight_decay, momentum, mse in results:
    print(f"lr={lr}, weight_decay={weight_decay}, momentum={momentum}, mse={mse}")

print("\nBest hyperparameters:")
print(f"lr={best_hyperparams[0]}, weight_decay={best_hyperparams[1]}, momentum={best_hyperparams[2]}, mse={best_hyperparams[3]}")


lr = 1e-06, weight_decay = 0, momentum = 0
epoch: 0


  0%|          | 0/6119 [00:00<?, ?it/s]

In [20]:
best_hyperparams

(0.0001, 0.01, 0.1, 6.144037642480696)

In [16]:
BATCH_SIZE = 32
EPOCHS = 50

In [13]:
def get_mse(model, data, batch_size=BATCH_SIZE, name="", device=torch.device('cpu'), **kw):
    squared_error = 0
    num_samples = 0
    with torch.no_grad():
        for batch in iterate_minibatches(data, batch_size=batch_size, shuffle=False, device=device, **kw):
            batch_pred = model(batch)
            squared_error += torch.sum(torch.square(batch_pred - batch[TARGET_COLUMN]))
            num_samples += len(batch_pred)
    mse = squared_error.detach().cpu().numpy() / num_samples
    return mse

In [18]:
criterion = nn.MSELoss(reduction='sum')
# The best set of optimizer hyperparameters
model = SalaryPredictor()
# The loaded weights are a combination of randomly initialized layers
# and a pretrained embedding layer
model.load_state_dict(torch.load(initial_weights_path))
model = model.to('cuda')
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, weight_decay=0.01, momentum=0.1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.8, patience=5)


for epoch in range(EPOCHS):
    print(f"epoch: {epoch}")
    model.train()
    for i, batch in tqdm(enumerate(
            iterate_minibatches(data_train, batch_size=BATCH_SIZE, device=device, word_dropout=0.4)),
            total=len(data_train) // BATCH_SIZE
        ):
        optimizer.zero_grad()
        pred = model(batch)
        loss = criterion(pred, batch[TARGET_COLUMN])
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # might be reduntant
        optimizer.step()
        
    print_metrics(model, data_val, device=device)
    val_mse = get_mse(model, data_val, device=device)
    scheduler.step(val_mse)

epoch: 0


  model.load_state_dict(torch.load(initial_weights_path))


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 2.31705
Mean absolute error: 1.42878
Mean absolute error in dollars: 26188.51395
epoch: 1


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.84487
Mean absolute error: 0.77102
Mean absolute error in dollars: 18326.17527
epoch: 2


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.91749
Mean absolute error: 0.83047
Mean absolute error in dollars: 19339.06802
epoch: 3


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 1.49372
Mean absolute error: 1.12176
Mean absolute error in dollars: 23032.44090
epoch: 4


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 1.17862
Mean absolute error: 0.98571
Mean absolute error in dollars: 21414.97340
epoch: 5


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 1.26297
Mean absolute error: 1.04174
Mean absolute error in dollars: 22342.02427
epoch: 6


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.87892
Mean absolute error: 0.84727
Mean absolute error in dollars: 19894.89954
epoch: 7


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.58841
Mean absolute error: 0.67937
Mean absolute error in dollars: 17335.45778
epoch: 8


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.57172
Mean absolute error: 0.66180
Mean absolute error in dollars: 16799.30057
epoch: 9


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.52877
Mean absolute error: 0.64716
Mean absolute error in dollars: 16705.72832
epoch: 10


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.35837
Mean absolute error: 0.51367
Mean absolute error in dollars: 14335.57250
epoch: 11


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.40764
Mean absolute error: 0.54765
Mean absolute error in dollars: 14836.88230
epoch: 12


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.38691
Mean absolute error: 0.53654
Mean absolute error in dollars: 14663.00478
epoch: 13


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.31456
Mean absolute error: 0.48052
Mean absolute error in dollars: 13799.39698
epoch: 14


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.27193
Mean absolute error: 0.43824
Mean absolute error in dollars: 12668.77346
epoch: 15


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.18210
Mean absolute error: 0.34384
Mean absolute error in dollars: 10507.17523
epoch: 16


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.24543
Mean absolute error: 0.41672
Mean absolute error in dollars: 12283.52233
epoch: 17


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.16267
Mean absolute error: 0.32177
Mean absolute error in dollars: 10125.71018
epoch: 18


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.20467
Mean absolute error: 0.37498
Mean absolute error in dollars: 11422.29129
epoch: 19


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.09234
Mean absolute error: 0.22880
Mean absolute error in dollars: 7843.77563
epoch: 20


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.11526
Mean absolute error: 0.26166
Mean absolute error in dollars: 8663.42509
epoch: 21


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.12943
Mean absolute error: 0.28285
Mean absolute error in dollars: 9106.64608
epoch: 22


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.09261
Mean absolute error: 0.23001
Mean absolute error in dollars: 7782.19095
epoch: 23


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.09679
Mean absolute error: 0.23608
Mean absolute error in dollars: 7943.77089
epoch: 24


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.15470
Mean absolute error: 0.31510
Mean absolute error in dollars: 9913.53842
epoch: 25


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.12850
Mean absolute error: 0.28157
Mean absolute error in dollars: 9172.99931
epoch: 26


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.10036
Mean absolute error: 0.24212
Mean absolute error in dollars: 8130.07346
epoch: 27


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.11159
Mean absolute error: 0.25841
Mean absolute error in dollars: 8566.98648
epoch: 28


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08305
Mean absolute error: 0.21681
Mean absolute error in dollars: 7495.29632
epoch: 29


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08201
Mean absolute error: 0.21499
Mean absolute error in dollars: 7391.77579
epoch: 30


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.10251
Mean absolute error: 0.24568
Mean absolute error in dollars: 8102.43870
epoch: 31


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.09267
Mean absolute error: 0.23129
Mean absolute error in dollars: 7740.64469
epoch: 32


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.10553
Mean absolute error: 0.25075
Mean absolute error in dollars: 8201.76067
epoch: 33


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08683
Mean absolute error: 0.22239
Mean absolute error in dollars: 7590.92961
epoch: 34


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.10902
Mean absolute error: 0.25561
Mean absolute error in dollars: 8500.61364
epoch: 35


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.11206
Mean absolute error: 0.26015
Mean absolute error in dollars: 8606.32267
epoch: 36


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08245
Mean absolute error: 0.21651
Mean absolute error in dollars: 7714.34931
epoch: 37


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.09959
Mean absolute error: 0.24221
Mean absolute error in dollars: 8127.49667
epoch: 38


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08197
Mean absolute error: 0.21501
Mean absolute error in dollars: 7400.36835
epoch: 39


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08016
Mean absolute error: 0.21238
Mean absolute error in dollars: 7324.30020
epoch: 40


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08658
Mean absolute error: 0.22231
Mean absolute error in dollars: 7547.05822
epoch: 41


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.10227
Mean absolute error: 0.24605
Mean absolute error in dollars: 8219.04907
epoch: 42


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08681
Mean absolute error: 0.22264
Mean absolute error in dollars: 7573.28954
epoch: 43


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08086
Mean absolute error: 0.21333
Mean absolute error in dollars: 7324.10671
epoch: 44


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07967
Mean absolute error: 0.21112
Mean absolute error in dollars: 7263.47771
epoch: 45


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.09194
Mean absolute error: 0.23077
Mean absolute error in dollars: 7826.97553
epoch: 46


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08167
Mean absolute error: 0.21503
Mean absolute error in dollars: 7350.77926
epoch: 47


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07557
Mean absolute error: 0.20474
Mean absolute error in dollars: 7099.17784
epoch: 48


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08616
Mean absolute error: 0.22210
Mean absolute error in dollars: 7551.51432
epoch: 49


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07736
Mean absolute error: 0.20792
Mean absolute error in dollars: 7166.52204


In [19]:
# let us run for yet another 50 epochs
for epoch in range(EPOCHS):
    print(f"epoch: {epoch}")
    model.train()
    for i, batch in tqdm(enumerate(
            iterate_minibatches(data_train, batch_size=BATCH_SIZE, device=device, word_dropout=0.4)),
            total=len(data_train) // BATCH_SIZE
        ):
        optimizer.zero_grad()
        pred = model(batch)
        loss = criterion(pred, batch[TARGET_COLUMN])
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # might be reduntant
        optimizer.step()
        
    print_metrics(model, data_val, device=device)
    val_mse = get_mse(model, data_val, device=device)
    scheduler.step(val_mse)

epoch: 0


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08082
Mean absolute error: 0.21347
Mean absolute error in dollars: 7277.87229
epoch: 1


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08065
Mean absolute error: 0.21465
Mean absolute error in dollars: 7619.09515
epoch: 2


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08391
Mean absolute error: 0.21834
Mean absolute error in dollars: 7442.22184
epoch: 3


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07709
Mean absolute error: 0.20761
Mean absolute error in dollars: 7171.60502
epoch: 4


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07452
Mean absolute error: 0.20320
Mean absolute error in dollars: 7072.99882
epoch: 5


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08010
Mean absolute error: 0.21228
Mean absolute error in dollars: 7320.72133
epoch: 6


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07979
Mean absolute error: 0.21172
Mean absolute error in dollars: 7274.81767
epoch: 7


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07456
Mean absolute error: 0.20342
Mean absolute error in dollars: 7051.41251
epoch: 8


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08490
Mean absolute error: 0.22020
Mean absolute error in dollars: 7477.05879
epoch: 9


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07452
Mean absolute error: 0.20329
Mean absolute error in dollars: 7044.81954
epoch: 10


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.09221
Mean absolute error: 0.23197
Mean absolute error in dollars: 7803.65306
epoch: 11


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07473
Mean absolute error: 0.20358
Mean absolute error in dollars: 7068.55252
epoch: 12


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07335
Mean absolute error: 0.20196
Mean absolute error in dollars: 7004.30347
epoch: 13


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08046
Mean absolute error: 0.21341
Mean absolute error in dollars: 7289.28480
epoch: 14


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08112
Mean absolute error: 0.21422
Mean absolute error in dollars: 7339.30008
epoch: 15


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07331
Mean absolute error: 0.20186
Mean absolute error in dollars: 7020.85909
epoch: 16


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07349
Mean absolute error: 0.20228
Mean absolute error in dollars: 7050.74642
epoch: 17


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08037
Mean absolute error: 0.21340
Mean absolute error in dollars: 7282.80623
epoch: 18


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08261
Mean absolute error: 0.21689
Mean absolute error in dollars: 7393.10144
epoch: 19


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08288
Mean absolute error: 0.21690
Mean absolute error in dollars: 7500.20869
epoch: 20


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07371
Mean absolute error: 0.20222
Mean absolute error in dollars: 7001.51751
epoch: 21


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08192
Mean absolute error: 0.21588
Mean absolute error in dollars: 7367.16297
epoch: 22


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07375
Mean absolute error: 0.20245
Mean absolute error in dollars: 7000.65073
epoch: 23


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07417
Mean absolute error: 0.20293
Mean absolute error in dollars: 7018.04437
epoch: 24


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07608
Mean absolute error: 0.20600
Mean absolute error in dollars: 7100.70940
epoch: 25


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07660
Mean absolute error: 0.20695
Mean absolute error in dollars: 7103.40123
epoch: 26


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08101
Mean absolute error: 0.21480
Mean absolute error in dollars: 7303.10675
epoch: 27


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.09063
Mean absolute error: 0.23011
Mean absolute error in dollars: 7743.99804
epoch: 28


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08164
Mean absolute error: 0.21555
Mean absolute error in dollars: 7371.30204
epoch: 29


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07871
Mean absolute error: 0.21098
Mean absolute error in dollars: 7214.75998
epoch: 30


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.09236
Mean absolute error: 0.23343
Mean absolute error in dollars: 7783.25122
epoch: 31


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07758
Mean absolute error: 0.20901
Mean absolute error in dollars: 7171.95081
epoch: 32


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08183
Mean absolute error: 0.21580
Mean absolute error in dollars: 7413.70658
epoch: 33


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07193
Mean absolute error: 0.19904
Mean absolute error in dollars: 6932.27634
epoch: 34


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07749
Mean absolute error: 0.20859
Mean absolute error in dollars: 7165.85399
epoch: 35


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08278
Mean absolute error: 0.21771
Mean absolute error in dollars: 7410.61731
epoch: 36


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08499
Mean absolute error: 0.22161
Mean absolute error in dollars: 7483.18634
epoch: 37


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.10008
Mean absolute error: 0.24547
Mean absolute error in dollars: 8155.50337
epoch: 38


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08267
Mean absolute error: 0.21749
Mean absolute error in dollars: 7388.03677
epoch: 39


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08463
Mean absolute error: 0.22084
Mean absolute error in dollars: 7501.01532
epoch: 40


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07316
Mean absolute error: 0.20148
Mean absolute error in dollars: 6997.53597
epoch: 41


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07396
Mean absolute error: 0.20270
Mean absolute error in dollars: 7003.77203
epoch: 42


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07845
Mean absolute error: 0.21060
Mean absolute error in dollars: 7199.43457
epoch: 43


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08404
Mean absolute error: 0.21962
Mean absolute error in dollars: 7515.87858
epoch: 44


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07153
Mean absolute error: 0.19887
Mean absolute error in dollars: 6911.46072
epoch: 45


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08244
Mean absolute error: 0.21721
Mean absolute error in dollars: 7375.35809
epoch: 46


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08318
Mean absolute error: 0.21810
Mean absolute error in dollars: 7472.01503
epoch: 47


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.09764
Mean absolute error: 0.24191
Mean absolute error in dollars: 8035.13176
epoch: 48


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.11343
Mean absolute error: 0.26615
Mean absolute error in dollars: 8685.10226
epoch: 49


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07450
Mean absolute error: 0.20371
Mean absolute error in dollars: 7048.46901


In [20]:
# let us run for yet another 100 epochs
for epoch in range(100):
    print(f"epoch: {epoch}")
    model.train()
    for i, batch in tqdm(enumerate(
            iterate_minibatches(data_train, batch_size=BATCH_SIZE, device=device, word_dropout=0.4)),
            total=len(data_train) // BATCH_SIZE
        ):
        optimizer.zero_grad()
        pred = model(batch)
        loss = criterion(pred, batch[TARGET_COLUMN])
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # might be reduntant
        optimizer.step()
        
    print_metrics(model, data_val, device=device)
    val_mse = get_mse(model, data_val, device=device)
    scheduler.step(val_mse)

epoch: 0


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08378
Mean absolute error: 0.21963
Mean absolute error in dollars: 7416.47751
epoch: 1


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08055
Mean absolute error: 0.21409
Mean absolute error in dollars: 7295.18683
epoch: 2


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08687
Mean absolute error: 0.22442
Mean absolute error in dollars: 7583.59668
epoch: 3


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07567
Mean absolute error: 0.20609
Mean absolute error in dollars: 7035.79229
epoch: 4


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.09316
Mean absolute error: 0.23481
Mean absolute error in dollars: 7887.47510
epoch: 5


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07630
Mean absolute error: 0.20701
Mean absolute error in dollars: 7129.13184
epoch: 6


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.09271
Mean absolute error: 0.23438
Mean absolute error in dollars: 7874.23753
epoch: 7


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07349
Mean absolute error: 0.20197
Mean absolute error in dollars: 6997.76933
epoch: 8


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08341
Mean absolute error: 0.21906
Mean absolute error in dollars: 7445.72162
epoch: 9


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08442
Mean absolute error: 0.22096
Mean absolute error in dollars: 7446.86163
epoch: 10


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.09384
Mean absolute error: 0.23626
Mean absolute error in dollars: 7894.54263
epoch: 11


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07333
Mean absolute error: 0.20200
Mean absolute error in dollars: 6993.98194
epoch: 12


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08320
Mean absolute error: 0.21842
Mean absolute error in dollars: 7510.52564
epoch: 13


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08712
Mean absolute error: 0.22534
Mean absolute error in dollars: 7568.62361
epoch: 14


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07794
Mean absolute error: 0.20999
Mean absolute error in dollars: 7146.64902
epoch: 15


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07465
Mean absolute error: 0.20426
Mean absolute error in dollars: 7049.95155
epoch: 16


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08342
Mean absolute error: 0.21891
Mean absolute error in dollars: 7428.83458
epoch: 17


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07781
Mean absolute error: 0.20955
Mean absolute error in dollars: 7159.00609
epoch: 18


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.09882
Mean absolute error: 0.24430
Mean absolute error in dollars: 8128.09348
epoch: 19


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08019
Mean absolute error: 0.21390
Mean absolute error in dollars: 7281.51456
epoch: 20


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08974
Mean absolute error: 0.22948
Mean absolute error in dollars: 7736.96254
epoch: 21


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07184
Mean absolute error: 0.19961
Mean absolute error in dollars: 6920.97234
epoch: 22


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07654
Mean absolute error: 0.20755
Mean absolute error in dollars: 7123.65077
epoch: 23


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07438
Mean absolute error: 0.20406
Mean absolute error in dollars: 7022.82077
epoch: 24


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07182
Mean absolute error: 0.19932
Mean absolute error in dollars: 6898.00155
epoch: 25


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08200
Mean absolute error: 0.21697
Mean absolute error in dollars: 7366.26940
epoch: 26


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07417
Mean absolute error: 0.20342
Mean absolute error in dollars: 7011.00952
epoch: 27


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08250
Mean absolute error: 0.21750
Mean absolute error in dollars: 7426.61535
epoch: 28


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07306
Mean absolute error: 0.20175
Mean absolute error in dollars: 6969.54561
epoch: 29


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08331
Mean absolute error: 0.21905
Mean absolute error in dollars: 7410.11595
epoch: 30


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07784
Mean absolute error: 0.20963
Mean absolute error in dollars: 7160.71022
epoch: 31


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07433
Mean absolute error: 0.20369
Mean absolute error in dollars: 7027.94231
epoch: 32


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07356
Mean absolute error: 0.20249
Mean absolute error in dollars: 6981.09213
epoch: 33


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08849
Mean absolute error: 0.22767
Mean absolute error in dollars: 7688.91939
epoch: 34


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08225
Mean absolute error: 0.21721
Mean absolute error in dollars: 7373.06827
epoch: 35


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08267
Mean absolute error: 0.21825
Mean absolute error in dollars: 7364.10704
epoch: 36


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08229
Mean absolute error: 0.21731
Mean absolute error in dollars: 7397.88765
epoch: 37


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07664
Mean absolute error: 0.20795
Mean absolute error in dollars: 7108.44368
epoch: 38


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07792
Mean absolute error: 0.21006
Mean absolute error in dollars: 7170.99579
epoch: 39


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07914
Mean absolute error: 0.21220
Mean absolute error in dollars: 7239.15709
epoch: 40


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08082
Mean absolute error: 0.21489
Mean absolute error in dollars: 7323.53540
epoch: 41


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07965
Mean absolute error: 0.21290
Mean absolute error in dollars: 7271.59962
epoch: 42


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07741
Mean absolute error: 0.20902
Mean absolute error in dollars: 7191.39176
epoch: 43


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07329
Mean absolute error: 0.20200
Mean absolute error in dollars: 6962.45193
epoch: 44


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07886
Mean absolute error: 0.21178
Mean absolute error in dollars: 7232.95371
epoch: 45


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08084
Mean absolute error: 0.21493
Mean absolute error in dollars: 7324.91269
epoch: 46


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07998
Mean absolute error: 0.21366
Mean absolute error in dollars: 7241.75479
epoch: 47


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07352
Mean absolute error: 0.20232
Mean absolute error in dollars: 6977.07530
epoch: 48


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08016
Mean absolute error: 0.21421
Mean absolute error in dollars: 7236.59795
epoch: 49


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07988
Mean absolute error: 0.21319
Mean absolute error in dollars: 7277.18985
epoch: 50


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07347
Mean absolute error: 0.20239
Mean absolute error in dollars: 6958.41745
epoch: 51


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07827
Mean absolute error: 0.21065
Mean absolute error in dollars: 7195.24059
epoch: 52


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08018
Mean absolute error: 0.21391
Mean absolute error in dollars: 7279.99150
epoch: 53


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08852
Mean absolute error: 0.22782
Mean absolute error in dollars: 7675.33015
epoch: 54


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.09453
Mean absolute error: 0.23772
Mean absolute error in dollars: 7880.01079
epoch: 55


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07542
Mean absolute error: 0.20550
Mean absolute error in dollars: 7083.54521
epoch: 56


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07628
Mean absolute error: 0.20734
Mean absolute error in dollars: 7110.58708
epoch: 57


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08181
Mean absolute error: 0.21669
Mean absolute error in dollars: 7339.47526
epoch: 58


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08080
Mean absolute error: 0.21518
Mean absolute error in dollars: 7277.79058
epoch: 59


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07758
Mean absolute error: 0.20932
Mean absolute error in dollars: 7195.75765
epoch: 60


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08402
Mean absolute error: 0.22063
Mean absolute error in dollars: 7411.61678
epoch: 61


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08348
Mean absolute error: 0.21952
Mean absolute error in dollars: 7473.87147
epoch: 62


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07707
Mean absolute error: 0.20871
Mean absolute error in dollars: 7114.37513
epoch: 63


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07476
Mean absolute error: 0.20483
Mean absolute error in dollars: 7021.04866
epoch: 64


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07922
Mean absolute error: 0.21217
Mean absolute error in dollars: 7243.86681
epoch: 65


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07971
Mean absolute error: 0.21304
Mean absolute error in dollars: 7295.37574
epoch: 66


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08197
Mean absolute error: 0.21716
Mean absolute error in dollars: 7367.86959
epoch: 67


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07673
Mean absolute error: 0.20780
Mean absolute error in dollars: 7185.44331
epoch: 68


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07721
Mean absolute error: 0.20864
Mean absolute error in dollars: 7202.60947
epoch: 69


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07735
Mean absolute error: 0.20906
Mean absolute error in dollars: 7158.95902
epoch: 70


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07286
Mean absolute error: 0.20130
Mean absolute error in dollars: 6944.65433
epoch: 71


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07812
Mean absolute error: 0.21041
Mean absolute error in dollars: 7219.35597
epoch: 72


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07611
Mean absolute error: 0.20696
Mean absolute error in dollars: 7094.96164
epoch: 73


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07927
Mean absolute error: 0.21240
Mean absolute error in dollars: 7239.77350
epoch: 74


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07924
Mean absolute error: 0.21230
Mean absolute error in dollars: 7259.03599
epoch: 75


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07751
Mean absolute error: 0.20951
Mean absolute error in dollars: 7159.87940
epoch: 76


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07844
Mean absolute error: 0.21073
Mean absolute error in dollars: 7234.67745
epoch: 77


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08260
Mean absolute error: 0.21794
Mean absolute error in dollars: 7426.52645
epoch: 78


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08343
Mean absolute error: 0.21966
Mean absolute error in dollars: 7411.51023
epoch: 79


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07421
Mean absolute error: 0.20364
Mean absolute error in dollars: 7021.20358
epoch: 80


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07596
Mean absolute error: 0.20667
Mean absolute error in dollars: 7086.87699
epoch: 81


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08239
Mean absolute error: 0.21803
Mean absolute error in dollars: 7373.45067
epoch: 82


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08079
Mean absolute error: 0.21501
Mean absolute error in dollars: 7302.10859
epoch: 83


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07599
Mean absolute error: 0.20673
Mean absolute error in dollars: 7113.25669
epoch: 84


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07948
Mean absolute error: 0.21283
Mean absolute error in dollars: 7269.04637
epoch: 85


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07970
Mean absolute error: 0.21329
Mean absolute error in dollars: 7239.54929
epoch: 86


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07814
Mean absolute error: 0.21040
Mean absolute error in dollars: 7191.17996
epoch: 87


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07972
Mean absolute error: 0.21308
Mean absolute error in dollars: 7315.89459
epoch: 88


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07880
Mean absolute error: 0.21166
Mean absolute error in dollars: 7229.36242
epoch: 89


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08304
Mean absolute error: 0.21900
Mean absolute error in dollars: 7408.15492
epoch: 90


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.08002
Mean absolute error: 0.21388
Mean absolute error in dollars: 7267.65927
epoch: 91


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07641
Mean absolute error: 0.20735
Mean absolute error in dollars: 7122.20942
epoch: 92


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07912
Mean absolute error: 0.21206
Mean absolute error in dollars: 7246.36647
epoch: 93


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07535
Mean absolute error: 0.20557
Mean absolute error in dollars: 7059.63378
epoch: 94


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07985
Mean absolute error: 0.21335
Mean absolute error in dollars: 7287.23749
epoch: 95


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07767
Mean absolute error: 0.20956
Mean absolute error in dollars: 7219.68803
epoch: 96


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07842
Mean absolute error: 0.21106
Mean absolute error in dollars: 7186.79707
epoch: 97


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07917
Mean absolute error: 0.21226
Mean absolute error in dollars: 7234.03620
epoch: 98


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07877
Mean absolute error: 0.21154
Mean absolute error in dollars: 7225.29722
epoch: 99


  0%|          | 0/6119 [00:00<?, ?it/s]

 results:
Mean square error: 0.07613
Mean absolute error: 0.20732
Mean absolute error in dollars: 7073.35834


### A short report

Please tell us what you did and how did it work.

1. **Increased the Complexity of the Architecture:** I added batch normalization layers, introduced more kernels, and implemented a two-layered convolution over the job description. Additionally, I used pretrained weights for the embedding layer. As a result, the performance dropped because training the optimization algorithm for only a few epochs was insufficient to achieve high accuracy.

2. **Applied Complex Pooling Strategies:** I implemented advanced pooling methods. Combined with longer training using stochastic gradient descent (SGD), the performance improved.

Additional ways to improve the performance:
1. Apply existing tools to correct "obvious" typos in the dataset.
2. Use MSE on the original label scale, not logarithmic one. It seems to me the training then focuses more on the center of the distribution while making larger errors on distributions tails.

## Recommended options

#### A) CNN architecture

All the tricks you know about dense and convolutional neural networks apply here as well.
* Dropout. Nuff said.
* Batch Norm. This time it's `nn.BatchNorm*`/`L.BatchNormalization`
* Parallel convolution layers. The idea is that you apply several nn.Conv1d to the same embeddings and concatenate output channels.
* More layers, more neurons, ya know...


#### B) Play with pooling

There's more than one way to perform pooling:
* Max over time (independently for each feature)
* Average over time (excluding PAD)
* Softmax-pooling:
$$ out_{i, t} = \sum_t {h_{i,t} \cdot {{e ^ {h_{i, t}}} \over \sum_\tau e ^ {h_{j, \tau}} } }$$

* Attentive pooling
$$ out_{i, t} = \sum_t {h_{i,t} \cdot Attn(h_t)}$$

, where $$ Attn(h_t) = {{e ^ {NN_{attn}(h_t)}} \over \sum_\tau e ^ {NN_{attn}(h_\tau)}}  $$
and $NN_{attn}$ is a dense layer.

The optimal score is usually achieved by concatenating several different poolings, including several attentive pooling with different $NN_{attn}$ (aka multi-headed attention).

The catch is that keras layers do not inlude those toys. You will have to [write your own keras layer](https://keras.io/layers/writing-your-own-keras-layers/). Or use pure tensorflow, it might even be easier :)

#### C) Fun with words

It's not always a good idea to train embeddings from scratch. Here's a few tricks:

* Use a pre-trained embeddings from `gensim.downloader.load`. See last lecture.
* Start with pre-trained embeddings, then fine-tune them with gradient descent. You may or may not download pre-trained embeddings from [here](http://nlp.stanford.edu/data/glove.6B.zip) and follow this [manual](https://keras.io/examples/nlp/pretrained_word_embeddings/) to initialize your Keras embedding layer with downloaded weights.
* Use the same embedding matrix in title and desc vectorizer


#### D) Going recurrent

We've already learned that recurrent networks can do cool stuff in sequence modelling. Turns out, they're not useless for classification as well. With some tricks of course..

* Like convolutional layers, LSTM should be pooled into a fixed-size vector with some of the poolings.
* Since you know all the text in advance, use bidirectional RNN
  * Run one LSTM from left to right
  * Run another in parallel from right to left 
  * Concatenate their output sequences along unit axis (dim=-1)

* It might be good idea to mix convolutions and recurrent layers differently for title and description


#### E) Optimizing seriously

* You don't necessarily need 100 epochs. Use early stopping. If you've never done this before, take a look at [early stopping callback(keras)](https://keras.io/callbacks/#earlystopping) or in [pytorch(lightning)](https://pytorch-lightning.readthedocs.io/en/latest/common/early_stopping.html).
  * In short, train until you notice that validation
  * Maintain the best-on-validation snapshot via `model.save(file_name)`
  * Plotting learning curves is usually a good idea
  
Good luck! And may the force be with you!