In [1]:
from config import *
from dataloader import TextDataset
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm, trange

from transformers import BertPreTrainedModel, BertModel
from transformers import AutoConfig, AutoTokenizer, BertConfig

Some weights of the model checkpoint at cointegrated/rubert-tiny were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not i

In [2]:
df = pd.read_csv('../../data/frame2.csv')[['description_clear', 'salary_from_rate_and_gross_log']]

  has_raised = await self.run_ast_nodes(code_ast.body, cell_name,


In [3]:
import numpy as np
df['salary_from_rate_and_gross_log'] = np.exp(df['salary_from_rate_and_gross_log'])

In [4]:
df.reset_index(inplace=True)

In [5]:
class DescriptionDataset(Dataset):

    def __init__(self, data, maxlen, tokenizer): 
        #Store the contents of the file in a pandas dataframe
        self.df = data.reset_index()
        #Initialize the tokenizer for the desired transformer model
        self.tokenizer = tokenizer
        #Maximum length of the tokens list to keep all the sequences of fixed size
        self.maxlen = maxlen

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):    
        #Select the sentence and label at the specified index in the data frame
        excerpt = self.df.loc[index, 'description_clear']
        try:
            target = self.df.loc[index, 'salary_from_rate_and_gross_log']
        except:
            target = 0.0
        identifier = self.df.loc[index, 'index']
        #Preprocess the text to be suitable for the transformer
        tokens = self.tokenizer.tokenize(excerpt) 
        tokens = ['[CLS]'] + tokens + ['[SEP]'] 
        if len(tokens) < self.maxlen:
            tokens = tokens + ['[PAD]' for _ in range(self.maxlen - len(tokens))] 
        else:
            tokens = tokens[:self.maxlen-1] + ['[SEP]'] 
        #Obtain the indices of the tokens in the BERT Vocabulary
        input_ids = self.tokenizer.convert_tokens_to_ids(tokens) 
        input_ids = torch.tensor(input_ids) 
        #Obtain the attention mask i.e a tensor containing 1s for no padded tokens and 0s for padded ones
        attention_mask = (input_ids != 0).long()
        
        target = torch.tensor(target, dtype=torch.float32)
        
        return input_ids, attention_mask, target

In [6]:
class BertRegresser(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        #The output layer that takes the [CLS] representation and gives an output
        self.cls_layer1 = nn.Linear(config.hidden_size,128)
        self.relu1 = nn.ReLU()
        self.ff1 = nn.Linear(128,128)
        self.tanh1 = nn.Tanh()
        self.ff2 = nn.Linear(128,1)

    def forward(self, input_ids, attention_mask):
        #Feed the input to Bert model to obtain contextualized representations
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        #Obtain the representations of [CLS] heads
        logits = outputs.last_hidden_state[:,0,:]
        output = self.cls_layer1(logits)
        output = self.relu1(output)
        output = self.ff1(output)
        output = self.tanh1(output)
        output = self.ff2(output)
        return output, outputs

In [7]:
def train(model, criterion, optimizer, train_loader, val_loader, epochs, device):
    best_acc = 0
    print('Start')
    for epoch in trange(epochs, desc="Epoch"):
        model.train()
        train_loss = 0
        train_rmse = 0
        for i, (input_ids, attention_mask, target) in enumerate(iterable=train_loader):
            optimizer.zero_grad()  
            
            input_ids, attention_mask, target = input_ids.to(device), attention_mask.to(device), target.to(device)
            
            output, _ = model(input_ids=input_ids, attention_mask=attention_mask)
            target= torch.unsqueeze(target, dim=1)
            loss = criterion(output, target.type_as(output))
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            train_rmse += get_rmse(output, target.type_as(output), criterion=criterion)
            if i % 100 == 0:
                print(f'Batch:{i}/{len(train_loader)}')
        
        print(f"Training loss is {train_loss/len(train_loader)}\trmse is {train_rmse/len(train_loader)}")
        val_loss, val_rmse = evaluate(model=model, criterion=criterion, dataloader=val_loader, device=device)
        print("Epoch {} complete! Validation Loss : {}".format(epoch, val_loss))
        print(f"Epoch {epoch} complete! Validation loss is {val_loss}\trmse is {val_rmse}")

In [8]:
def evaluate(model, criterion, dataloader, device):
    model.eval()
    mean_acc, mean_loss, mean_err, count = 0, 0, 0, 0
    with torch.no_grad():
        for input_ids, attention_mask, target in (dataloader):
            
            input_ids, attention_mask, target = input_ids.to(device), attention_mask.to(device), target.to(device)
            output, _ = model(input_ids, attention_mask)
            target= torch.unsqueeze(target, dim=1)
            mean_loss += criterion(output, target.type_as(output)).item()
            mean_err += get_rmse(output, target, criterion=criterion)
            count += 1
            
    return mean_loss/count, mean_err/count

In [9]:
from sklearn import metrics

def get_rmse(output, target, criterion):
    with torch.no_grad():
        err = torch.sqrt(criterion(target, output))
        return err

In [10]:
def predict(model, dataloader, device):
    predicted_label = []
    actual_label = []
    with torch.no_grad():
        for input_ids, attention_mask, target in (dataloader):
            
            input_ids, attention_mask, target = input_ids.to(device), attention_mask.to(device), target.to(device)
            output, _ = model(input_ids, attention_mask)
                        
            predicted_label += output
            actual_label += target
            
    return predicted_label

In [11]:
train_dataset, valid_dataset = train_test_split(df, test_size=0.3, random_state=0)
valid_dataset, test_dataset = train_test_split(valid_dataset, test_size=0.5, random_state=0)

In [12]:
config = BertConfig.from_pretrained("cointegrated/rubert-tiny", output_hidden_states=True)
## Tokenizer loaded from AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny")
## Creating the model from the desired transformer model
model = BertRegresser.from_pretrained("cointegrated/rubert-tiny", config=config)
## GPU or CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
## Putting model to device
model = model.to(device)
## Takes as the input the logits of the positive class and computes the binary cross-entropy 
# criterion = nn.BCEWithLogitsLoss()
criterion = nn.MSELoss()
## Optimizer
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

Some weights of the model checkpoint at cointegrated/rubert-tiny were not used when initializing BertRegresser: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertRegresser from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertRegresser from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertRegresser were not initialized from the model checkpoint at cointegrated/rubert-tiny

In [13]:
train_set = DescriptionDataset(data=train_dataset, maxlen=256, tokenizer=tokenizer)
valid_set = DescriptionDataset(data=valid_dataset, maxlen=256, tokenizer=tokenizer)
test_set = DescriptionDataset(data=test_dataset, maxlen=256, tokenizer=tokenizer)

In [14]:
train_loader = DataLoader(dataset=train_set, batch_size=50)
valid_loader = DataLoader(dataset=valid_set, batch_size=50)
test_loader = DataLoader(dataset=test_set, batch_size=50)

In [15]:
train(model=model, 
      criterion=criterion,
      optimizer=optimizer, 
      train_loader=train_loader,
      val_loader=valid_loader,
      epochs = 50,
     device = device)

Start


Epoch:   0%|          | 0/50 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (573 > 512). Running this sequence through the model will result in indexing errors


Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 5631816106.783083	rmse is 74741.2265625


Epoch:   2%|▏         | 1/50 [05:38<4:36:28, 338.53s/it]

Epoch 0 complete! Validation Loss : 5596397018.576271
Epoch 0 complete! Validation loss is 5596397018.576271	rmse is 74535.5390625
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 5594965657.087768	rmse is 74493.8203125


Epoch:   4%|▍         | 2/50 [11:15<4:29:55, 337.41s/it]

Epoch 1 complete! Validation Loss : 5559727154.983051
Epoch 1 complete! Validation loss is 5559727154.983051	rmse is 74288.875
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 5558337307.474306	rmse is 74247.3125


Epoch:   6%|▌         | 3/50 [16:52<4:24:13, 337.31s/it]

Epoch 2 complete! Validation Loss : 5523233310.915255
Epoch 2 complete! Validation loss is 5523233310.915255	rmse is 74042.5234375
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 5521876000.713052	rmse is 74000.890625


Epoch:   8%|▊         | 4/50 [22:28<4:18:09, 336.74s/it]

Epoch 3 complete! Validation Loss : 5486901181.288136
Epoch 3 complete! Validation loss is 5486901181.288136	rmse is 73796.4921875
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 5485576063.243293	rmse is 73754.9375


Epoch:  10%|█         | 5/50 [28:04<4:12:26, 336.59s/it]

Epoch 4 complete! Validation Loss : 5450730251.389831
Epoch 4 complete! Validation loss is 5450730251.389831	rmse is 73550.7578125
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 5449436878.639381	rmse is 73509.2109375


Epoch:  12%|█▏        | 6/50 [33:54<4:10:13, 341.23s/it]

Epoch 5 complete! Validation Loss : 5414719771.118644
Epoch 5 complete! Validation loss is 5414719771.118644	rmse is 73305.1875
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 5413458427.692588	rmse is 73263.703125


Epoch:  14%|█▍        | 7/50 [39:39<4:05:22, 342.39s/it]

Epoch 6 complete! Validation Loss : 5378870204.745763
Epoch 6 complete! Validation loss is 5378870204.745763	rmse is 73060.0
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 5377641051.736244	rmse is 73018.5234375


Epoch:  16%|█▌        | 8/50 [45:24<4:00:08, 343.06s/it]

Epoch 7 complete! Validation Loss : 5343182336.0
Epoch 7 complete! Validation loss is 5343182336.0	rmse is 72814.9921875
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 5341984535.050477	rmse is 72773.609375


Epoch:  18%|█▊        | 9/50 [51:08<3:54:43, 343.50s/it]

Epoch 8 complete! Validation Loss : 5307654755.254237
Epoch 8 complete! Validation loss is 5307654755.254237	rmse is 72570.328125
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 5306488068.656662	rmse is 72528.859375


Epoch:  20%|██        | 10/50 [56:53<3:49:20, 344.01s/it]

Epoch 9 complete! Validation Loss : 5272286211.79661
Epoch 9 complete! Validation loss is 5272286211.79661	rmse is 72325.90625
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 5271151765.94452	rmse is 72284.6171875


Epoch:  22%|██▏       | 11/50 [1:02:38<3:43:50, 344.36s/it]

Epoch 10 complete! Validation Loss : 5237078784.542373
Epoch 10 complete! Validation loss is 5237078784.542373	rmse is 72081.84375
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 5235975865.917235	rmse is 72040.4453125


Epoch:  24%|██▍       | 12/50 [1:08:24<3:38:22, 344.79s/it]

Epoch 11 complete! Validation Loss : 5202031951.18644
Epoch 11 complete! Validation loss is 5202031951.18644	rmse is 71837.984375
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 5200960575.912687	rmse is 71796.65625


Epoch:  26%|██▌       | 13/50 [1:14:10<3:32:46, 345.05s/it]

Epoch 12 complete! Validation Loss : 5167145101.016949
Epoch 12 complete! Validation loss is 5167145101.016949	rmse is 71594.421875
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 5166105942.963165	rmse is 71553.1484375


Epoch:  28%|██▊       | 14/50 [1:19:46<3:25:22, 342.28s/it]

Epoch 13 complete! Validation Loss : 5132419905.084745
Epoch 13 complete! Validation loss is 5132419905.084745	rmse is 71351.1953125
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 5131411893.144156	rmse is 71309.953125


Epoch:  30%|███       | 15/50 [1:25:31<3:20:14, 343.26s/it]

Epoch 14 complete! Validation Loss : 5097853935.18644
Epoch 14 complete! Validation loss is 5097853935.18644	rmse is 71108.2890625
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 5096878496.189177	rmse is 71067.046875


Epoch:  32%|███▏      | 16/50 [1:31:08<3:13:27, 341.39s/it]

Epoch 15 complete! Validation Loss : 5063449684.067797
Epoch 15 complete! Validation loss is 5063449684.067797	rmse is 70865.609375
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 5062506189.475216	rmse is 70824.40625


Epoch:  34%|███▍      | 17/50 [1:36:44<3:06:54, 339.82s/it]

Epoch 16 complete! Validation Loss : 5029206619.118644
Epoch 16 complete! Validation loss is 5029206619.118644	rmse is 70623.2890625
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 5028294559.490678	rmse is 70582.109375


Epoch:  36%|███▌      | 18/50 [1:42:18<3:00:14, 337.94s/it]

Epoch 17 complete! Validation Loss : 4995124513.627119
Epoch 17 complete! Validation loss is 4995124513.627119	rmse is 70381.2734375
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4994243716.016371	rmse is 70340.078125


Epoch:  38%|███▊      | 19/50 [1:48:05<2:56:02, 340.74s/it]

Epoch 18 complete! Validation Loss : 4961202044.20339
Epoch 18 complete! Validation loss is 4961202044.20339	rmse is 70139.53125
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4960353374.646658	rmse is 70098.546875


Epoch:  40%|████      | 20/50 [1:53:54<2:51:32, 343.07s/it]

Epoch 19 complete! Validation Loss : 4927440081.355932
Epoch 19 complete! Validation loss is 4927440081.355932	rmse is 69898.109375
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4926623544.345612	rmse is 69857.1171875


Epoch:  42%|████▏     | 21/50 [1:59:43<2:46:41, 344.87s/it]

Epoch 20 complete! Validation Loss : 4893839947.932203
Epoch 20 complete! Validation loss is 4893839947.932203	rmse is 69657.0390625
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4893054777.975444	rmse is 69615.9921875


Epoch:  44%|████▍     | 22/50 [2:05:39<2:42:29, 348.19s/it]

Epoch 21 complete! Validation Loss : 4860399880.135593
Epoch 21 complete! Validation loss is 4860399880.135593	rmse is 69416.234375
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4859646590.079127	rmse is 69375.171875


Epoch:  46%|████▌     | 23/50 [2:11:36<2:37:53, 350.86s/it]

Epoch 22 complete! Validation Loss : 4827120713.220339
Epoch 22 complete! Validation loss is 4827120713.220339	rmse is 69175.796875
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4826399088.574807	rmse is 69134.7109375


Epoch:  48%|████▊     | 24/50 [2:17:29<2:32:21, 351.59s/it]

Epoch 23 complete! Validation Loss : 4794001535.457627
Epoch 23 complete! Validation loss is 4794001535.457627	rmse is 68935.6953125
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4793311698.364711	rmse is 68894.65625


Epoch:  50%|█████     | 25/50 [2:23:28<2:27:26, 353.87s/it]

Epoch 24 complete! Validation Loss : 4761043105.627119
Epoch 24 complete! Validation loss is 4761043105.627119	rmse is 68695.875
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4760384715.845385	rmse is 68654.859375


Epoch:  52%|█████▏    | 26/50 [2:29:17<2:20:54, 352.27s/it]

Epoch 25 complete! Validation Loss : 4728244629.694915
Epoch 25 complete! Validation loss is 4728244629.694915	rmse is 68456.4140625
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4727618384.560255	rmse is 68415.4609375


Epoch:  54%|█████▍    | 27/50 [2:34:59<2:13:53, 349.28s/it]

Epoch 26 complete! Validation Loss : 4695607719.050847
Epoch 26 complete! Validation loss is 4695607719.050847	rmse is 68217.28125
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4695013182.515689	rmse is 68176.3984375


Epoch:  56%|█████▌    | 28/50 [2:40:38<2:06:54, 346.14s/it]

Epoch 27 complete! Validation Loss : 4663130762.847458
Epoch 27 complete! Validation loss is 4663130762.847458	rmse is 67978.46875
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4662567976.163711	rmse is 67937.5625


Epoch:  58%|█████▊    | 29/50 [2:46:27<2:01:25, 346.93s/it]

Epoch 28 complete! Validation Loss : 4630814033.355932
Epoch 28 complete! Validation loss is 4630814033.355932	rmse is 67740.0625
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4630283733.973624	rmse is 67699.21875


Epoch:  60%|██████    | 30/50 [2:52:14<1:55:39, 346.98s/it]

Epoch 29 complete! Validation Loss : 4598658411.389831
Epoch 29 complete! Validation loss is 4598658411.389831	rmse is 67501.9375
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4598159260.69668	rmse is 67461.171875


Epoch:  62%|██████▏   | 31/50 [2:58:06<1:50:20, 348.45s/it]

Epoch 30 complete! Validation Loss : 4566662627.254237
Epoch 30 complete! Validation loss is 4566662627.254237	rmse is 67264.1640625
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4566195099.416099	rmse is 67223.46875


Epoch:  64%|██████▍   | 32/50 [3:03:53<1:44:27, 348.21s/it]

Epoch 31 complete! Validation Loss : 4534826848.542373
Epoch 31 complete! Validation loss is 4534826848.542373	rmse is 67026.78125
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4534390806.817644	rmse is 66986.0625


Epoch:  66%|██████▌   | 33/50 [3:09:40<1:38:31, 347.72s/it]

Epoch 32 complete! Validation Loss : 4503151656.135593
Epoch 32 complete! Validation loss is 4503151656.135593	rmse is 66789.7265625
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4502747484.90041	rmse is 66749.0234375


Epoch:  68%|██████▊   | 34/50 [3:15:26<1:32:37, 347.36s/it]

Epoch 33 complete! Validation Loss : 4471637309.830508
Epoch 33 complete! Validation loss is 4471637309.830508	rmse is 66552.9921875
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4471264909.096862	rmse is 66512.3828125


Epoch:  70%|███████   | 35/50 [3:21:12<1:26:44, 346.98s/it]

Epoch 34 complete! Validation Loss : 4440283375.18644
Epoch 34 complete! Validation loss is 4440283375.18644	rmse is 66316.7421875
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4439942773.871759	rmse is 66276.1796875


Epoch:  72%|███████▏  | 36/50 [3:26:59<1:20:55, 346.83s/it]

Epoch 35 complete! Validation Loss : 4409089806.644068
Epoch 35 complete! Validation loss is 4409089806.644068	rmse is 66080.8046875
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4408780715.190541	rmse is 66040.2421875


Epoch:  74%|███████▍  | 37/50 [3:32:45<1:15:07, 346.69s/it]

Epoch 36 complete! Validation Loss : 4378056266.305085
Epoch 36 complete! Validation loss is 4378056266.305085	rmse is 65845.2265625
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4377779352.796726	rmse is 65804.765625


Epoch:  76%|███████▌  | 38/50 [3:38:33<1:09:24, 347.05s/it]

Epoch 37 complete! Validation Loss : 4347183972.338983
Epoch 37 complete! Validation loss is 4347183972.338983	rmse is 65610.0
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4346938947.929059	rmse is 65569.546875


Epoch:  78%|███████▊  | 39/50 [3:44:20<1:03:37, 347.07s/it]

Epoch 38 complete! Validation Loss : 4316471952.81356
Epoch 38 complete! Validation loss is 4316471952.81356	rmse is 65375.17578125
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4316259203.841746	rmse is 65334.70703125


Epoch:  80%|████████  | 40/50 [3:50:07<57:50, 347.02s/it]  

Epoch 39 complete! Validation Loss : 4285921392.2711864
Epoch 39 complete! Validation loss is 4285921392.2711864	rmse is 65140.765625
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4285739860.6930423	rmse is 65100.3828125


Epoch:  82%|████████▏ | 41/50 [3:55:54<52:02, 346.99s/it]

Epoch 40 complete! Validation Loss : 4255530748.745763
Epoch 40 complete! Validation loss is 4255530748.745763	rmse is 64906.7421875
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4255380961.8481126	rmse is 64866.4140625


Epoch:  84%|████████▍ | 42/50 [4:01:43<46:19, 347.47s/it]

Epoch 41 complete! Validation Loss : 4225300269.559322
Epoch 41 complete! Validation loss is 4225300269.559322	rmse is 64673.08203125
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4225181860.671214	rmse is 64632.7578125


Epoch:  86%|████████▌ | 43/50 [4:07:43<40:58, 351.27s/it]

Epoch 42 complete! Validation Loss : 4195230065.898305
Epoch 42 complete! Validation loss is 4195230065.898305	rmse is 64439.83203125
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4195143611.1978173	rmse is 64399.5703125


Epoch:  88%|████████▊ | 44/50 [4:13:51<35:38, 356.45s/it]

Epoch 43 complete! Validation Loss : 4165320617.762712
Epoch 43 complete! Validation loss is 4165320617.762712	rmse is 64206.984375
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4165266169.3060484	rmse is 64166.73828125


Epoch:  90%|█████████ | 45/50 [4:19:42<29:32, 354.54s/it]

Epoch 44 complete! Validation Loss : 4135572863.1864405
Epoch 44 complete! Validation loss is 4135572863.1864405	rmse is 63974.58203125
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4135549549.3733516	rmse is 63934.2734375


Epoch:  92%|█████████▏| 46/50 [4:25:34<23:35, 353.88s/it]

Epoch 45 complete! Validation Loss : 4105984576.8135595
Epoch 45 complete! Validation loss is 4105984576.8135595	rmse is 63742.546875
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4105993788.827649	rmse is 63702.41015625


Epoch:  94%|█████████▍| 47/50 [4:31:22<17:36, 352.26s/it]

Epoch 46 complete! Validation Loss : 4076557862.779661
Epoch 46 complete! Validation loss is 4076557862.779661	rmse is 63510.9375
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4076598621.075034	rmse is 63470.76953125


Epoch:  96%|█████████▌| 48/50 [4:37:14<11:44, 352.10s/it]

Epoch 47 complete! Validation Loss : 4047291531.3898306
Epoch 47 complete! Validation loss is 4047291531.3898306	rmse is 63279.796875
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4047363556.8749433	rmse is 63239.71875


Epoch:  98%|█████████▊| 49/50 [4:43:13<05:54, 354.01s/it]

Epoch 48 complete! Validation Loss : 4018185019.9322033
Epoch 48 complete! Validation loss is 4018185019.9322033	rmse is 63048.9765625
Batch:0/2199
Batch:100/2199
Batch:200/2199
Batch:300/2199
Batch:400/2199
Batch:500/2199
Batch:600/2199
Batch:700/2199
Batch:800/2199
Batch:900/2199
Batch:1000/2199
Batch:1100/2199
Batch:1200/2199
Batch:1300/2199
Batch:1400/2199
Batch:1500/2199
Batch:1600/2199
Batch:1700/2199
Batch:1800/2199
Batch:1900/2199
Batch:2000/2199
Batch:2100/2199
Training loss is 4018288796.5220556	rmse is 63008.9375


Epoch: 100%|██████████| 50/50 [4:49:18<00:00, 347.16s/it]

Epoch 49 complete! Validation Loss : 3989239307.118644
Epoch 49 complete! Validation loss is 3989239307.118644	rmse is 62818.6796875





In [17]:
torch.save(model.state_dict(), 'regression_model/third.pth')

In [18]:
bert = BertRegresser.from_pretrained("cointegrated/rubert-tiny", config=config)
bert.load_state_dict(torch.load('regression_model/third.pth'))

Some weights of the model checkpoint at cointegrated/rubert-tiny were not used when initializing BertRegresser: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertRegresser from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertRegresser from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertRegresser were not initialized from the model checkpoint at cointegrated/rubert-tiny

<All keys matched successfully>

In [74]:
train_set[0][1].unsqueeze(0)

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])

In [19]:
model.eval()
with torch.no_grad():
    output, z = model(train_set[0][0].to(device).unsqueeze(0), train_set[0][1].to(device).unsqueeze(0))
# for input_ids, attention_mask, target in (dataloader):
            
#             input_ids, attention_mask, target = input_ids.to(device), attention_mask.to(device), target.to(device)
#             output = model(input_ids, attention_mask)

In [20]:
z['last_hidden_state'][0][0]

tensor([ 1.0165, -0.5941,  0.8446, -1.0152, -1.1903,  0.8911, -1.0150,  0.6692,
        -0.0823,  1.0956,  1.3675, -0.9616, -0.6173,  1.1120,  0.2065,  0.7564,
         0.0050, -0.9943, -0.8755, -1.4005, -0.9810,  0.7830, -1.1040, -0.6733,
         1.1221, -0.4395,  1.2176,  0.8622,  1.1510,  0.9769, -0.9694,  1.0767,
        -0.6339,  0.9127, -0.0927, -0.8705,  0.8907, -1.0492,  1.1331, -0.2640,
        -1.3647,  0.4010,  1.4742, -1.0280, -0.9186,  0.5417, -0.6405,  0.9277,
        -1.1187, -1.3352,  1.1012, -1.1103,  0.9032,  0.9564, -1.3408,  1.0993,
         1.1243,  0.9001, -1.2059, -0.8907,  1.1102,  0.9942,  0.8966, -0.7747,
         1.0933,  1.4284, -1.1696, -0.8015,  0.7859, -0.9075, -1.1437, -0.9550,
        -1.3351, -0.9288,  0.9722,  1.1186,  1.0152,  0.4491,  1.1752, -0.5556,
         1.1111, -0.8800, -1.0476, -0.6195,  0.7390,  0.8985, -1.0188, -0.8994,
        -0.0845, -0.9677, -0.9305, -1.1115,  1.7626,  1.0709,  1.0571,  0.6062,
         1.2125,  1.1210, -1.0072,  1.06

In [22]:
bert.cuda()
bert.eval()

with torch.no_grad():
    output, z = bert(train_set[0][0].to(device).unsqueeze(0), train_set[0][1].to(device).unsqueeze(0))

In [23]:
z['last_hidden_state'][0][0]

tensor([ 1.0165, -0.5941,  0.8446, -1.0152, -1.1903,  0.8911, -1.0150,  0.6692,
        -0.0823,  1.0956,  1.3675, -0.9616, -0.6173,  1.1120,  0.2065,  0.7564,
         0.0050, -0.9943, -0.8755, -1.4005, -0.9810,  0.7830, -1.1040, -0.6733,
         1.1221, -0.4395,  1.2176,  0.8622,  1.1510,  0.9769, -0.9694,  1.0767,
        -0.6339,  0.9127, -0.0927, -0.8705,  0.8907, -1.0492,  1.1331, -0.2640,
        -1.3647,  0.4010,  1.4742, -1.0280, -0.9186,  0.5417, -0.6405,  0.9277,
        -1.1187, -1.3352,  1.1012, -1.1103,  0.9032,  0.9564, -1.3408,  1.0993,
         1.1243,  0.9001, -1.2059, -0.8907,  1.1102,  0.9942,  0.8966, -0.7747,
         1.0933,  1.4284, -1.1696, -0.8015,  0.7859, -0.9075, -1.1437, -0.9550,
        -1.3351, -0.9288,  0.9722,  1.1186,  1.0152,  0.4491,  1.1752, -0.5556,
         1.1111, -0.8800, -1.0476, -0.6195,  0.7390,  0.8985, -1.0188, -0.8994,
        -0.0845, -0.9677, -0.9305, -1.1115,  1.7626,  1.0709,  1.0571,  0.6062,
         1.2125,  1.1210, -1.0072,  1.06

In [22]:
torch.save(model, 'regression_model/second.pth')

In [6]:
import torch
model_new = torch.load('regression_model/second.pth', map_location='cuda')

PermissionError: [Errno 13] Permission denied: 'regression_model/checkpoint-22000'

In [26]:
model.eval()
with torch.no_grad():
    output, z = model(train_set[0][0].to(device).unsqueeze(0), train_set[0][1].to(device).unsqueeze(0))

In [27]:
output

tensor([[8466.5273]], device='cuda:0')

In [7]:
from transformers import AutoModelForSequenceClassification
model_new = AutoModelForSequenceClassification.from_pretrained('regression_model/checkpoint-20900/', local_files_only=True)