In [1]:
from tokenizers import Tokenizer
import torch.optim as optim
import torch


In [2]:
# FOR NEWOUTPUT.TXT
from nltk.tokenize import word_tokenize
from collections import Counter
import re

max_bytes = 60 * 1024 * 1024

with open('new_output.txt', 'rb') as f:
    raw_bytes = f.read(max_bytes)

text = raw_bytes.decode('utf-8', errors='ignore')

chunk_size = 10_000_000  # About ~10MB per chunk
chunks = [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]

def stream_nltk_tokens(text_chunks):
    for chunk in text_chunks:
        tokens = word_tokenize(chunk.lower(), preserve_line=True)
        for token in tokens:
            if token.isalpha():
                yield token

word_counts = Counter()
print("Tokenizing chunks with NLTK...")

for i, token in enumerate(stream_nltk_tokens(chunks)):
    word_counts[token] += 1
    if i % 1_000_000 == 0:
        print(f"{i:,} tokens processed")

print("Tokenization complete!")
print(f"Unique words: {len(word_counts)}")

min_freq = 5
max_vocab_size = None

def is_clean_word(word):
    return re.fullmatch(r"[a-z]+", word) is not None

cleaned_word_counts = {
    word: freq for word, freq in word_counts.items()
    if is_clean_word(word) and freq >= min_freq
}



Tokenizing chunks with NLTK...
0 tokens processed
1,000,000 tokens processed
2,000,000 tokens processed
3,000,000 tokens processed
4,000,000 tokens processed
5,000,000 tokens processed
6,000,000 tokens processed
7,000,000 tokens processed
8,000,000 tokens processed
9,000,000 tokens processed
Tokenization complete!
Unique words: 158483


In [12]:
print(len(word_counts.keys()))

286501


In [11]:
word_to_idx = {word: idx for idx, word in enumerate(cleaned_word_counts)}
idx_to_word = {idx: word for word, idx in word_to_idx.items()}


print(f"Final cleaned vocab size: {len(word_to_idx)}")
print("Sample vocab entries:", list(word_to_idx.items())[:10])

Final cleaned vocab size: 45700
Sample vocab entries: [('are', 0), ('you', 1), ('a', 2), ('resident', 3), ('of', 4), ('pinnacle', 5), ('who', 6), ('owns', 7), ('small', 8), ('business', 9)]


In [4]:
from collections import deque
import torch
import gc

def generate_skipgram_pairs_from_stream(token_stream, word_to_idx, window_size=2, chunk_size=100_000):
    pairs_tensor_list = []
    pairs = []
    window = deque(maxlen=2 * window_size + 1)

    for token in token_stream:
        if token not in word_to_idx:
            continue

        window.append(token)

        if len(window) < window.maxlen:
            continue  # fill window first

        center_pos = window_size
        center_word = window[center_pos]

        for i, context_word in enumerate(window):
            if i == center_pos:
                continue
            if context_word in word_to_idx:
                pairs.append([word_to_idx[center_word], word_to_idx[context_word]])

        if len(pairs) >= chunk_size:
            chunk_tensor = torch.tensor(pairs, dtype=torch.long)
            pairs_tensor_list.append(chunk_tensor)
            pairs = []
            gc.collect()

    # Final leftovers
    if pairs:
        chunk_tensor = torch.tensor(pairs, dtype=torch.long)
        pairs_tensor_list.append(chunk_tensor)
        del pairs
        gc.collect()

    return torch.cat(pairs_tensor_list, dim=0)

pairs_tensor = generate_skipgram_pairs_from_stream(
    token_stream=stream_nltk_tokens(chunks),
    word_to_idx=word_to_idx,
    window_size=2,
    chunk_size=100_000
)


In [5]:
torch.save(pairs_tensor, "pairs_tensor_nltk_small.pt")

In [6]:
pairs_tensor = torch.load("pairs_tensor_nltk_small.pt")
print(pairs_tensor.shape)
print(pairs_tensor[:5])


  pairs_tensor = torch.load("pairs_tensor_nltk_small.pt")


torch.Size([38117316, 2])
tensor([[2, 0],
        [2, 1],
        [2, 3],
        [2, 4],
        [3, 1]])


In [7]:
from torch.utils.data import Dataset

class SkipGramDataset(Dataset):
    def __init__(self, pairs_tensor):
        self.pairs = pairs_tensor

    def __len__(self):
        return self.pairs.shape[0]
    
    def __getitem__(self, idx):
        center, context = self.pairs[idx]
        return center, context
    


In [8]:
from torch.utils.data import DataLoader

dataset = SkipGramDataset(pairs_tensor)
dataloader = DataLoader(dataset, batch_size=1024, shuffle=True, num_workers=0, pin_memory=True, persistent_workers=False)

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import random

class SkipGramNegSampling(nn.Module):
    def __init__(self, vocab_size, embed_dim):
        super(SkipGramNegSampling, self).__init__()
        self.in_embed = nn.Embedding(vocab_size, embed_dim)
        self.out_embed = nn.Embedding(vocab_size, embed_dim)

    def forward(self, center_words, pos_context_words, neg_context_words):
        center_embeds = self.in_embed(center_words)
        pos_embeds = self.out_embed(pos_context_words)
        neg_embeds = self.out_embed(neg_context_words)

        pos_score = torch.sum(center_embeds * pos_embeds, dim=1)
        pos_loss = F.logsigmoid(pos_score)

        neg_score = torch.bmm(neg_embeds.neg(), center_embeds.unsqueeze(2)).squeeze(2)
        neg_loss = F.logsigmoid(neg_score).sum(1)

        loss = -(pos_loss + neg_loss).mean()
        return loss
    
def get_negative_samples(batch_size, vocab_size, num_neg_samples):
    neg_samples = torch.randint(0, vocab_size, (batch_size, num_neg_samples))
    return neg_samples


In [10]:
vocab_size = len(word_to_idx)
embedding_dim = 256
num_neg_samples = 5
epochs = 5

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = SkipGramNegSampling(vocab_size, embedding_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.003)

print(device)

for epoch in range(epochs):
    total_loss = 0
    model.train()
    print(f"epoch {epoch}")
    for center, context in dataloader:
        center = center.to(device, non_blocking=True)
        context = context.to(device, non_blocking=True)

        batch_size = center.shape[0]
        negative_samples = get_negative_samples(batch_size, vocab_size, num_neg_samples).to(device, non_blocking=True)

        optimizer.zero_grad()
        loss = model(center, context, negative_samples)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    
    average_loss = total_loss / len(dataloader)
    print(f'Epoch {epoch + 1} / {epochs} - Loss: {average_loss:.4f}')
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': average_loss,
        'embedding_dim': embedding_dim,
        'vocab_size': vocab_size
    }

    torch.save(checkpoint, f"checkpoint_epoch_nltk{epoch+1}.pt")
    print("Embeddings checkpoint saved")

cuda
epoch 0
Epoch 1 / 5 - Loss: 3.5666
Embeddings checkpoint saved
epoch 1
Epoch 2 / 5 - Loss: 1.1926
Embeddings checkpoint saved
epoch 2
Epoch 3 / 5 - Loss: 1.0575
Embeddings checkpoint saved
epoch 3
Epoch 4 / 5 - Loss: 1.0186
Embeddings checkpoint saved
epoch 4
Epoch 5 / 5 - Loss: 1.0006
Embeddings checkpoint saved


In [1]:
import torch
print(torch.version.cuda)
print(torch.cuda.is_available()) 
print(torch.cuda.device_count()) 


12.1
True
1


In [13]:
# reload the checkpoint
checkpoint = torch.load('checkpoint_epoch_3.pt', map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

vocab_size = checkpoint['vocab_size']
embedding_dim = checkpoint['embedding_dim']
num_neg_samples = 5

model = SkipGramNegSampling(vocab_size, embedding_dim)
optimizer = optim.Adam(model.parameters(), lr=0.003)

model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

model.train()

start_epoch = checkpoint['epoch'] + 1  # resume at the next epoch
epochs = 5  

for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

print(f"Resuming from epoch {start_epoch}")

for epoch in range(start_epoch, epochs):
    total_loss = 0
    model.train()
    print(f"epoch {epoch}")
    for center, context in dataloader:
        center = center.to(device, non_blocking=True)
        context = context.to(device, non_blocking=True)

        batch_size = center.shape[0]
        negative_samples = get_negative_samples(batch_size, vocab_size, num_neg_samples)
        negative_samples = negative_samples.to(device, non_blocking=True)  

        optimizer.zero_grad()
        loss = model(center, context, negative_samples)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    
    average_loss = total_loss / len(dataloader)
    print(f'Epoch {epoch + 1} / {epochs} - Loss: {average_loss:.4f}')
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': average_loss,
        'embedding_dim': embedding_dim,
        'vocab_size': vocab_size
    }

    torch.save(checkpoint, f"checkpoint_epoch_{epoch+1}.pt")
    print("Embeddings checkpoint saved")

  checkpoint = torch.load('checkpoint_epoch_3.pt', map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))


Resuming from epoch 3
epoch 3
Epoch 4 / 5 - Loss: 0.8494
Embeddings checkpoint saved
epoch 4
Epoch 5 / 5 - Loss: 0.8504
Embeddings checkpoint saved
