In [58]:
import pandas as pd
import pyarrow as pa
import numpy as np

import sentencepiece as spm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from tqdm import tqdm
import io
import pickle
from torch.utils.data import DataLoader, Dataset
import os

In [3]:
# Load Parquet data
data = pd.read_parquet('/root/MLXWeek2_copy/hacker_news_data.parquet')
# Now, you can use the `data' DataFrame to analyse and manipulate the data.
print(data.head())

       id                                              title  score   type
0  187635                     ReactOS, a clone of Windows XP      1  story
1  187637  Music Operates Directly On Your Abstract Synta...      1  story
2  187788   Nudge: Helping you make rational daily decisions      1  story
3  187813           Scaling a Microblogging Service - Part I     14  story
4  188029        Good analysis of Powerset by Danny Sullivan      2  story


In [4]:
print(data.columns)

Index(['id', 'title', 'score', 'type'], dtype='object')


In [5]:
# Extract the columns with the correct rows
story_data = data[data["type"] == "story"]

# Extract the relevant columns
relevant_data = story_data[["title", "score"]]

# Remove any null entries
processed_data = relevant_data.dropna()

print(processed_data.head())

                                               title  score
0                     ReactOS, a clone of Windows XP      1
1  Music Operates Directly On Your Abstract Synta...      1
2   Nudge: Helping you make rational daily decisions      1
3           Scaling a Microblogging Service - Part I     14
4        Good analysis of Powerset by Danny Sullivan      2


In [6]:
# Convert to array
data_array = processed_data.to_numpy()
print(data_array)

[['ReactOS, a clone of Windows XP' 1]
 ['Music Operates Directly On Your Abstract Syntax Tree' 1]
 ['Nudge: Helping you make rational daily decisions' 1]
 ...
 ['How to write great Javascript Unit Tests' 6]
 ['Discount Climbing Gear' 1]
 ['Feature Request: Google please unify your Gtalk gadget!' 1]]


In [7]:
# Get the titles
title_array = data_array[:, 0]

In [43]:
# Load the trained SentencePiece model
sp = spm.SentencePieceProcessor()
sp.load('/root/MLXWeek2_copy/weights/skipgram_TechCrunch/techcrunch_sp.model')

# Tokenize titles
tokenized_titles = [sp.encode_as_pieces(title) for title in title_array]

# Turn the tokens into tokenids
id_titles = [[sp.piece_to_id(token) for token in title_tokens] for title_tokens in tokenized_titles]

In [44]:
# A `Dataset` class for generating pairs of id-tokenised titles and scores one at a time, to use in the `DataLoader` to generate mini-batches
class TitlesAndScores(Dataset):
    def __init__(self, titles, scores):
        """
        Args:
            token_ids (list of list of words): Nested list where each sublist contains token IDs for a sentence.
            scores (list of float): List of scores associated with each list of token IDs.
        """
        self.titles = titles
        self.tokenized_titles = [sp.encode_as_pieces(title) for title in self.titles]
        self.id_titles = [[sp.piece_to_id(token) for token in title_tokens] for title_tokens in self.tokenized_titles]
        self.scores = scores
    
    def __len__(self):
        return len(self.scores)
    
    def __getitem__(self, idx):
        return self.id_titles[idx], self.scores[idx]

In [10]:
# A padding function to make sure that every list of title token ids in a mini-batch is the same length, so that the batch can be tensorised
def pad_collate(batch):
    (id_titles, scores) = zip(*batch)
    
    # Padding the sequences with 0
    padded_titles = pad_sequence([torch.tensor(title_ids) for title_ids in id_titles], batch_first=True, padding_value=0)
    
    # Convert scores to a tensor
    tensor_scores = torch.tensor(scores, dtype=torch.float)

    return padded_titles, tensor_scores

In [11]:
# Create the dataset on our hackernews data
sp_dataset = TitlesAndScores(data_array[:,0], data_array[:,1])

# Create the dataLoader on our hackernews data
sp_dataloader = DataLoader(sp_dataset, batch_size=32, shuffle=True, collate_fn=pad_collate)

In [108]:
# Hyperparameters
hidden_dim_1= 64
hidden_dim_2 = 32
num_epochs = 100
vocab_size = 4000
embedding_dim = 300
learning_rate = 0.001
EMBED_MAX_NORM = 1

In [86]:
# Define the SkipGramModel class
class SkipGram_Model(nn.Module):
    def __init__(self, vocab_size, embedding_dim, max_norm=1):
        super(SkipGram_Model, self).__init__()
        self.vocab_size = 4000
        self.embedding_dim = embedding_dim
        self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim, max_norm=max_norm)
        self.linear = nn.Linear(in_features=embedding_dim, out_features=vocab_size)
        
    def forward(self, inputs_):
        x = self.embeddings(inputs_)
        x = self.linear(x)
        return x
    
    def embed(self, inputs_):
        # This method will return just the embeddings
        return self.embeddings(inputs_)



In [47]:
class CustomUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        if module == "utils":
            return super().find_class(__name__, name)
        return super().find_class(module, name)


In [87]:
skipgram_model = SkipGram_Model(vocab_size, embedding_dim)
skipgram_model = SkipGram_Model(vocab_size, embedding_dim)
print(f"Embedding dimension: {skipgram_model.embedding_dim}")

Embedding dimension: 300


In [88]:
for inputs, targets in sp_dataloader:
    print(f'Max input index: {inputs.max().item()}')  # Add this line to check the maximum index
    inputs = inputs.long()

Max input index: 3973
Max input index: 3844
Max input index: 3955
Max input index: 3926
Max input index: 3982
Max input index: 3985
Max input index: 3884
Max input index: 3955
Max input index: 3834
Max input index: 3845
Max input index: 3882
Max input index: 3955
Max input index: 3982
Max input index: 3879
Max input index: 3987
Max input index: 3963
Max input index: 3882
Max input index: 3975
Max input index: 3897
Max input index: 3982
Max input index: 3964
Max input index: 3854
Max input index: 3815
Max input index: 3975
Max input index: 3882
Max input index: 3987
Max input index: 3898
Max input index: 3987
Max input index: 3987
Max input index: 3985
Max input index: 3964
Max input index: 3757
Max input index: 3818
Max input index: 3902
Max input index: 3941
Max input index: 3846
Max input index: 3859
Max input index: 3958
Max input index: 3876
Max input index: 3762
Max input index: 3915
Max input index: 3735
Max input index: 3987
Max input index: 3963
Max input index: 3975
Max input 

Max input index: 3664
Max input index: 3987
Max input index: 3979
Max input index: 3888
Max input index: 3915
Max input index: 3898
Max input index: 3985
Max input index: 3981
Max input index: 3926
Max input index: 3898
Max input index: 3985
Max input index: 3975
Max input index: 3947
Max input index: 3986
Max input index: 3915
Max input index: 3876
Max input index: 3973
Max input index: 3909
Max input index: 3875
Max input index: 3957
Max input index: 3842
Max input index: 3987
Max input index: 3881
Max input index: 3975
Max input index: 3951
Max input index: 3807
Max input index: 3893
Max input index: 3987
Max input index: 3868
Max input index: 3985
Max input index: 3987
Max input index: 3853
Max input index: 3959
Max input index: 3968
Max input index: 3973
Max input index: 3868
Max input index: 3762
Max input index: 3887
Max input index: 3926
Max input index: 3712
Max input index: 3975
Max input index: 3987
Max input index: 3847
Max input index: 3868
Max input index: 3987
Max input 

In [89]:

# Load the pre-trained SkipGram model
pretrained_model_path = '/root/MLXWeek2_copy/weights/skipgram_TechCrunch/model.pt'
skipgram_model = torch.load(pretrained_model_path, map_location=torch.device('cpu'))
skipgram_model.eval()  # Assuming you're doing inference

# Manually set the embedding_dim if it's not part of the loaded state_dict
skipgram_model.embedding_dim = 300

# Manually add the embed method to the loaded model if it's missing
def embed(self, inputs_):
    return self.embeddings(inputs_)

# Assign the method to the instance
skipgram_model.embed = embed.__get__(skipgram_model, SkipGram_Model)

# Freeze the parameters (weights) of the embedding model
for param in skipgram_model.parameters():
    param.requires_grad = False

# Now skipgram_model has an embedding_dim attribute you can access
print(f"Embedding dimension: {skipgram_model.embedding_dim}")

Embedding dimension: 300


In [135]:
# A model for predicting scores given titles. 
# This particular architecture is intended to use one of the word2vec models we trained earlier (CBOW or Skip-gram) -- with its weights frozen -- to embed the input title tokens.
# It then uses a classic neural network architecture to predict the scores given the embeddings.
# This architecture has two hidden layers with ReLU activations, and a final output layer with no activation (a linear output layer) because we are doing a regression task. 
class ScorePredictor(nn.Module):
    def __init__(self, hidden_dim_1, hidden_dim_2, embedding_model):
        super(ScorePredictor, self).__init__()
        self.embedding_model = embedding_model
        
        # Use embedding_model's embedding dimension
        embedding_dim = self.embedding_model.embedding_dim
        
        self.hidden_1 = nn.Linear(embedding_dim, hidden_dim_1)
        self.hidden_2 = nn.Linear(hidden_dim_1, hidden_dim_2)
        self.output = nn.Linear(hidden_dim_2, 1)
        self.relu = nn.ReLU()

    def forward(self, titles):
        embeddings = self.embedding_model.embed(titles)
        pooled_embeddings = embeddings.mean(dim=1)
        hidden_1_output = self.relu(self.hidden_1(pooled_embeddings))
        hidden_2_output = self.relu(self.hidden_2(hidden_1_output))
        score_predictions = self.output(hidden_2_output)
        return score_predictions

In [117]:
# Model instantiation
sp_model = ScorePredictor(hidden_dim_1, hidden_dim_2, skipgram_model)

# Loss function
sp_loss_function = nn.MSELoss()  # Mean Squared Error Loss for regression

# Optimizer - Excluding the frozen embeddings from being updated
sp_optimizer = optim.Adam(filter(lambda p: p.requires_grad, sp_model.parameters()), lr=learning_rate)

# Training Loop
for epoch in range(num_epochs):
    total_loss = 0
    for inputs, targets in tqdm(sp_dataloader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        
        inputs = inputs.long()
        sp_model.zero_grad()  # Reset gradients
        outputs = sp_model(inputs)  # Forward pass
        loss = sp_loss_function(outputs, targets.view(outputs.size()))  # Compute loss
        loss.backward()  # Backpropagation
        sp_optimizer.step()  # Update the parameters
        total_loss += loss.item()

    # Print the average loss for the epoch
    avg_loss = total_loss / len(sp_dataloader)
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss}')
    
    # Save the model state after each epoch, excluding the frozen embeddings
    torch.save({k: v for k, v in sp_model.state_dict().items() if k != 'embedding_model.embeddings.weight'}, os.path.join("sp_skipgram_weights", f"sp_skipgram_epoch_{epoch+1}.pth"))

    torch.save({
    'score_predictor_state_dict': sp_model.state_dict(),
    'embedding_model_state_dict': skipgram_model.state_dict()
}, '/root/MLXWeek2_copy/combined_model.pth')
    

Epoch 1/100: 100%|██████████| 5170/5170 [00:17<00:00, 292.65it/s]


Epoch 1/100, Loss: 269.09864696524824


Epoch 2/100: 100%|██████████| 5170/5170 [00:17<00:00, 292.64it/s]


Epoch 2/100, Loss: 265.714672713953


Epoch 3/100: 100%|██████████| 5170/5170 [00:17<00:00, 291.73it/s]


Epoch 3/100, Loss: 264.64463471093535


Epoch 4/100: 100%|██████████| 5170/5170 [00:17<00:00, 291.25it/s]


Epoch 4/100, Loss: 264.28166258164487


Epoch 5/100: 100%|██████████| 5170/5170 [00:17<00:00, 291.70it/s]


Epoch 5/100, Loss: 264.04702504100834


Epoch 6/100: 100%|██████████| 5170/5170 [00:17<00:00, 290.97it/s]


Epoch 6/100, Loss: 263.83881431096296


Epoch 7/100: 100%|██████████| 5170/5170 [00:17<00:00, 292.60it/s]


Epoch 7/100, Loss: 263.5806771604881


Epoch 8/100: 100%|██████████| 5170/5170 [00:17<00:00, 292.38it/s]


Epoch 8/100, Loss: 263.5127479457302


Epoch 9/100: 100%|██████████| 5170/5170 [00:17<00:00, 299.27it/s]


Epoch 9/100, Loss: 263.3687802598831


Epoch 10/100: 100%|██████████| 5170/5170 [00:17<00:00, 303.10it/s]


Epoch 10/100, Loss: 263.0677881192884


Epoch 11/100: 100%|██████████| 5170/5170 [00:17<00:00, 303.91it/s]


Epoch 11/100, Loss: 262.837935401946


Epoch 12/100: 100%|██████████| 5170/5170 [00:17<00:00, 302.08it/s]


Epoch 12/100, Loss: 262.6785739688394


Epoch 13/100: 100%|██████████| 5170/5170 [00:17<00:00, 292.81it/s]


Epoch 13/100, Loss: 262.52944835629637


Epoch 14/100: 100%|██████████| 5170/5170 [00:20<00:00, 249.52it/s]


Epoch 14/100, Loss: 262.357546403468


Epoch 15/100: 100%|██████████| 5170/5170 [00:19<00:00, 271.17it/s]


Epoch 15/100, Loss: 262.19379799564297


Epoch 16/100: 100%|██████████| 5170/5170 [00:17<00:00, 289.20it/s]


Epoch 16/100, Loss: 261.88580249144445


Epoch 17/100: 100%|██████████| 5170/5170 [00:17<00:00, 289.45it/s]


Epoch 17/100, Loss: 261.61496530197115


Epoch 18/100: 100%|██████████| 5170/5170 [00:18<00:00, 287.08it/s]


Epoch 18/100, Loss: 261.3702446253203


Epoch 19/100: 100%|██████████| 5170/5170 [00:17<00:00, 287.82it/s]


Epoch 19/100, Loss: 261.17801377279847


Epoch 20/100: 100%|██████████| 5170/5170 [00:17<00:00, 288.18it/s]


Epoch 20/100, Loss: 261.3010341644287


Epoch 21/100: 100%|██████████| 5170/5170 [00:18<00:00, 287.04it/s]


Epoch 21/100, Loss: 260.87199874191725


Epoch 22/100: 100%|██████████| 5170/5170 [00:17<00:00, 288.11it/s]


Epoch 22/100, Loss: 260.6698973179786


Epoch 23/100: 100%|██████████| 5170/5170 [00:18<00:00, 286.63it/s]


Epoch 23/100, Loss: 260.654635398577


Epoch 24/100: 100%|██████████| 5170/5170 [00:17<00:00, 287.57it/s]


Epoch 24/100, Loss: 260.44415007301643


Epoch 25/100: 100%|██████████| 5170/5170 [00:17<00:00, 287.41it/s]


Epoch 25/100, Loss: 260.31889861990453


Epoch 26/100: 100%|██████████| 5170/5170 [00:18<00:00, 284.50it/s]


Epoch 26/100, Loss: 259.86961580863084


Epoch 27/100: 100%|██████████| 5170/5170 [00:18<00:00, 284.03it/s]


Epoch 27/100, Loss: 259.75396165718655


Epoch 28/100: 100%|██████████| 5170/5170 [00:18<00:00, 285.66it/s]


Epoch 28/100, Loss: 259.66082986743123


Epoch 29/100: 100%|██████████| 5170/5170 [00:18<00:00, 284.66it/s]


Epoch 29/100, Loss: 258.98903721785405


Epoch 30/100: 100%|██████████| 5170/5170 [00:18<00:00, 285.40it/s]


Epoch 30/100, Loss: 259.3454405587231


Epoch 31/100: 100%|██████████| 5170/5170 [00:18<00:00, 282.41it/s]


Epoch 31/100, Loss: 259.13170494157066


Epoch 32/100: 100%|██████████| 5170/5170 [00:18<00:00, 281.96it/s]


Epoch 32/100, Loss: 258.98720272735653


Epoch 33/100: 100%|██████████| 5170/5170 [00:18<00:00, 284.97it/s]


Epoch 33/100, Loss: 258.7331502178421


Epoch 34/100: 100%|██████████| 5170/5170 [00:18<00:00, 284.91it/s]


Epoch 34/100, Loss: 258.5669558207809


Epoch 35/100: 100%|██████████| 5170/5170 [00:18<00:00, 283.50it/s]


Epoch 35/100, Loss: 257.7852782397021


Epoch 36/100: 100%|██████████| 5170/5170 [00:17<00:00, 290.78it/s]


Epoch 36/100, Loss: 258.1663801108384


Epoch 37/100: 100%|██████████| 5170/5170 [00:17<00:00, 291.39it/s]


Epoch 37/100, Loss: 257.5791542802143


Epoch 38/100: 100%|██████████| 5170/5170 [00:17<00:00, 291.64it/s]


Epoch 38/100, Loss: 257.33157874740067


Epoch 39/100: 100%|██████████| 5170/5170 [00:17<00:00, 292.35it/s]


Epoch 39/100, Loss: 257.43680192607974


Epoch 40/100: 100%|██████████| 5170/5170 [00:17<00:00, 292.52it/s]


Epoch 40/100, Loss: 257.2217857821998


Epoch 41/100: 100%|██████████| 5170/5170 [00:17<00:00, 292.68it/s]


Epoch 41/100, Loss: 256.6331100275715


Epoch 42/100: 100%|██████████| 5170/5170 [00:17<00:00, 290.32it/s]


Epoch 42/100, Loss: 257.08816581814614


Epoch 43/100: 100%|██████████| 5170/5170 [00:18<00:00, 284.47it/s]


Epoch 43/100, Loss: 255.7817204563945


Epoch 44/100: 100%|██████████| 5170/5170 [00:18<00:00, 286.64it/s]


Epoch 44/100, Loss: 256.0071702534733


Epoch 45/100: 100%|██████████| 5170/5170 [00:17<00:00, 287.25it/s]


Epoch 45/100, Loss: 255.19041711098913


Epoch 46/100: 100%|██████████| 5170/5170 [00:18<00:00, 286.74it/s]


Epoch 46/100, Loss: 256.05263860654554


Epoch 47/100: 100%|██████████| 5170/5170 [00:18<00:00, 286.11it/s]


Epoch 47/100, Loss: 254.58354274539468


Epoch 48/100: 100%|██████████| 5170/5170 [00:18<00:00, 286.81it/s]


Epoch 48/100, Loss: 254.0086112917046


Epoch 49/100: 100%|██████████| 5170/5170 [00:18<00:00, 286.82it/s]


Epoch 49/100, Loss: 254.337232438093


Epoch 50/100: 100%|██████████| 5170/5170 [00:18<00:00, 287.15it/s]


Epoch 50/100, Loss: 253.859877084994


Epoch 51/100: 100%|██████████| 5170/5170 [00:18<00:00, 286.81it/s]


Epoch 51/100, Loss: 254.14693888640267


Epoch 52/100: 100%|██████████| 5170/5170 [00:18<00:00, 285.21it/s]


Epoch 52/100, Loss: 252.9356704556965


Epoch 53/100: 100%|██████████| 5170/5170 [00:18<00:00, 285.95it/s]


Epoch 53/100, Loss: 253.02916465589573


Epoch 54/100: 100%|██████████| 5170/5170 [00:18<00:00, 284.65it/s]


Epoch 54/100, Loss: 252.75419735807054


Epoch 55/100: 100%|██████████| 5170/5170 [00:18<00:00, 286.40it/s]


Epoch 55/100, Loss: 253.1971234493145


Epoch 56/100: 100%|██████████| 5170/5170 [00:18<00:00, 286.76it/s]


Epoch 56/100, Loss: 252.223998045783


Epoch 57/100: 100%|██████████| 5170/5170 [00:18<00:00, 285.73it/s]


Epoch 57/100, Loss: 251.5350153681388


Epoch 58/100: 100%|██████████| 5170/5170 [00:18<00:00, 285.94it/s]


Epoch 58/100, Loss: 250.04023211864714


Epoch 59/100: 100%|██████████| 5170/5170 [00:18<00:00, 286.09it/s]


Epoch 59/100, Loss: 251.24789731212005


Epoch 60/100: 100%|██████████| 5170/5170 [00:18<00:00, 285.99it/s]


Epoch 60/100, Loss: 249.57111327542097


Epoch 61/100: 100%|██████████| 5170/5170 [00:18<00:00, 285.16it/s]


Epoch 61/100, Loss: 251.66071180513333


Epoch 62/100: 100%|██████████| 5170/5170 [00:18<00:00, 282.42it/s]


Epoch 62/100, Loss: 248.351289769252


Epoch 63/100: 100%|██████████| 5170/5170 [00:17<00:00, 291.85it/s]


Epoch 63/100, Loss: 249.24566545468005


Epoch 64/100: 100%|██████████| 5170/5170 [00:17<00:00, 291.46it/s]


Epoch 64/100, Loss: 249.2598870696136


Epoch 65/100: 100%|██████████| 5170/5170 [00:17<00:00, 290.41it/s]


Epoch 65/100, Loss: 247.90420665298257


Epoch 66/100: 100%|██████████| 5170/5170 [00:17<00:00, 291.03it/s]


Epoch 66/100, Loss: 247.2660249584648


Epoch 67/100: 100%|██████████| 5170/5170 [00:17<00:00, 289.17it/s]


Epoch 67/100, Loss: 248.80623962044484


Epoch 68/100: 100%|██████████| 5170/5170 [00:17<00:00, 290.86it/s]


Epoch 68/100, Loss: 246.01000677105074


Epoch 69/100: 100%|██████████| 5170/5170 [00:18<00:00, 272.46it/s]


Epoch 69/100, Loss: 247.6616535255019


Epoch 70/100: 100%|██████████| 5170/5170 [00:18<00:00, 285.13it/s]


Epoch 70/100, Loss: 248.55868882782454


Epoch 71/100: 100%|██████████| 5170/5170 [00:21<00:00, 245.77it/s]


Epoch 71/100, Loss: 245.94747436032765


Epoch 72/100: 100%|██████████| 5170/5170 [00:18<00:00, 286.22it/s]


Epoch 72/100, Loss: 245.39658095970375


Epoch 73/100: 100%|██████████| 5170/5170 [00:18<00:00, 286.39it/s]


Epoch 73/100, Loss: 245.0534360101652


Epoch 74/100: 100%|██████████| 5170/5170 [00:18<00:00, 284.23it/s]


Epoch 74/100, Loss: 246.01517367335308


Epoch 75/100: 100%|██████████| 5170/5170 [00:18<00:00, 280.57it/s]


Epoch 75/100, Loss: 245.10644167156698


Epoch 76/100: 100%|██████████| 5170/5170 [00:19<00:00, 267.40it/s]


Epoch 76/100, Loss: 246.18555370862063


Epoch 77/100: 100%|██████████| 5170/5170 [00:18<00:00, 286.50it/s]


Epoch 77/100, Loss: 243.97447761889813


Epoch 78/100: 100%|██████████| 5170/5170 [00:18<00:00, 285.60it/s]


Epoch 78/100, Loss: 245.58053017741707


Epoch 79/100: 100%|██████████| 5170/5170 [00:18<00:00, 286.37it/s]


Epoch 79/100, Loss: 245.07820915096733


Epoch 80/100: 100%|██████████| 5170/5170 [00:18<00:00, 285.81it/s]


Epoch 80/100, Loss: 243.0607866154418


Epoch 81/100: 100%|██████████| 5170/5170 [00:18<00:00, 283.60it/s]


Epoch 81/100, Loss: 245.4591762350883


Epoch 82/100: 100%|██████████| 5170/5170 [00:18<00:00, 285.94it/s]


Epoch 82/100, Loss: 242.49581006786119


Epoch 83/100: 100%|██████████| 5170/5170 [00:18<00:00, 286.75it/s]


Epoch 83/100, Loss: 242.84695904988155


Epoch 84/100: 100%|██████████| 5170/5170 [00:18<00:00, 286.58it/s]


Epoch 84/100, Loss: 242.0133325152517


Epoch 85/100: 100%|██████████| 5170/5170 [00:18<00:00, 286.84it/s]


Epoch 85/100, Loss: 240.87439401781535


Epoch 86/100: 100%|██████████| 5170/5170 [00:18<00:00, 286.90it/s]


Epoch 86/100, Loss: 241.28867129320332


Epoch 87/100: 100%|██████████| 5170/5170 [00:18<00:00, 286.24it/s]


Epoch 87/100, Loss: 241.01190380561283


Epoch 88/100: 100%|██████████| 5170/5170 [00:18<00:00, 286.53it/s]


Epoch 88/100, Loss: 241.42798582128677


Epoch 89/100: 100%|██████████| 5170/5170 [00:18<00:00, 286.50it/s]


Epoch 89/100, Loss: 242.31770011488658


Epoch 90/100: 100%|██████████| 5170/5170 [00:18<00:00, 286.01it/s]


Epoch 90/100, Loss: 241.0361669431572


Epoch 91/100: 100%|██████████| 5170/5170 [00:17<00:00, 287.25it/s]


Epoch 91/100, Loss: 239.71036654791473


Epoch 92/100: 100%|██████████| 5170/5170 [00:18<00:00, 286.63it/s]


Epoch 92/100, Loss: 241.36197114664074


Epoch 93/100: 100%|██████████| 5170/5170 [00:18<00:00, 286.34it/s]


Epoch 93/100, Loss: 240.03938413874553


Epoch 94/100: 100%|██████████| 5170/5170 [00:18<00:00, 286.71it/s]


Epoch 94/100, Loss: 238.22033875242192


Epoch 95/100: 100%|██████████| 5170/5170 [00:18<00:00, 286.20it/s]


Epoch 95/100, Loss: 239.90139626110084


Epoch 96/100: 100%|██████████| 5170/5170 [00:18<00:00, 286.77it/s]


Epoch 96/100, Loss: 238.15425920652467


Epoch 97/100: 100%|██████████| 5170/5170 [00:18<00:00, 286.01it/s]


Epoch 97/100, Loss: 239.270882395665


Epoch 98/100: 100%|██████████| 5170/5170 [00:18<00:00, 285.80it/s]


Epoch 98/100, Loss: 240.40559251386838


Epoch 99/100: 100%|██████████| 5170/5170 [00:18<00:00, 286.07it/s]


Epoch 99/100, Loss: 236.85643698048546


Epoch 100/100: 100%|██████████| 5170/5170 [00:20<00:00, 254.04it/s]


Epoch 100/100, Loss: 237.4271455178178


In [128]:
pretrained_model_path = '/root/MLXWeek2_copy/weights/skipgram_TechCrunch/model.pt'
embedding_model = torch.load(pretrained_model_path, map_location=torch.device('cpu'))
embedding_model.eval()  # Assuming you're doing inference

# Manually set the embedding_dim if it's not part of the loaded state_dict
embedding_model.embedding_dim = 300

# Manually add the embed method to the loaded model if it's missing
def embed(self, inputs_):
    return self.embeddings(inputs_)

# Assign the method to the instance
embedding_model.embed = embed.__get__(embedding_model, SkipGram_Model)

# Freeze the parameters (weights) of the embedding model
for param in embedding_model.parameters():
    param.requires_grad = False

In [132]:
# Define the path to the combined model
combined_model_path = '/root/MLXWeek2_copy/combined_model.pth'

# Load the combined state dictionary
combined_state = torch.load(combined_model_path, map_location=torch.device('cpu'))

# After loading the state dictionary
print("Keys in the state dict:", combined_state['score_predictor_state_dict'].keys())

# After loading the state dictionary
print("Keys in the state dict:", combined_state['embedding_model_state_dict'].keys())

Keys in the state dict: odict_keys(['embedding_model.embeddings.weight', 'embedding_model.linear.weight', 'embedding_model.linear.bias', 'hidden_1.weight', 'hidden_1.bias', 'hidden_2.weight', 'hidden_2.bias', 'output.weight', 'output.bias'])
Keys in the state dict: odict_keys(['embeddings.weight', 'linear.weight', 'linear.bias'])


In [136]:
score_predictor_model = ScorePredictor(hidden_dim_1, hidden_dim_2, embedding_model)
score_predictor_model.load_state_dict(combined_state['score_predictor_state_dict'])
score_predictor_model.eval()

ScorePredictor(
  (embedding_model): SkipGram_Model(
    (embeddings): Embedding(4000, 300, max_norm=1)
    (linear): Linear(in_features=300, out_features=4000, bias=True)
  )
  (hidden_1): Linear(in_features=300, out_features=64, bias=True)
  (hidden_2): Linear(in_features=64, out_features=32, bias=True)
  (output): Linear(in_features=32, out_features=1, bias=True)
  (relu): ReLU()
)

In [137]:


def score_predictor(title, score_predictor_model):
    # Tokenize the title using the sentencepiece tokenizer
    sp = spm.SentencePieceProcessor(model_file='/root/MLXWeek2_copy/weights/skipgram_TechCrunch/techcrunch_sp.model')
    title_tokens = sp.encode_as_pieces(title)
    title_ids = [sp.piece_to_id(token) for token in title_tokens]

    # Convert to tensor and add batch dimension
    title_tensor = torch.tensor(title_ids).unsqueeze(0)

    # Forward pass through the model
    with torch.no_grad():  # Ensure no gradients are computed
        score = score_predictor_model(title_tensor)

    return score.item()



In [138]:
# Now, use this function to predict the score
title = "The quick brown fox jumps over the lazy dog"
predicted_score = score_predictor(title, score_predictor_model)
print("Predicted score:", predicted_score)

Predicted score: 12.633469581604004
