In [12]:
#
#
#
import collections

import more_itertools
import torch
import tqdm
import wandb

# Pick the GPU when PyTorch reports CUDA as available, otherwise use the CPU.
# This has to come after `import torch`: the expression calls into torch, so
# evaluating it first raises NameError in a fresh interpreter.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


#
#
#
# Seed PyTorch's default random number generator with a fixed value so that
# random draws made through torch start from the same state on every run.
torch.manual_seed(42)


#
#
#
# Read the whole `text8` file (expected in the working directory) into memory.
# The encoding is pinned so decoding does not depend on the platform default.
with open('text8', encoding='utf-8') as f:
  text8: str = f.read()


#
#
#
def preprocess(text: str, rare_threshold: int = 5) -> list[str]:
  """Lower-case `text`, tokenise punctuation, and drop rare words.

  Each punctuation mark listed below is replaced by a named placeholder
  surrounded by spaces, so that it survives whitespace splitting as its own
  word. Words are then counted and every word that occurs `rare_threshold`
  times or fewer is removed.

  Args:
    text: Raw input text.
    rare_threshold: A word is kept only if its count is strictly greater
      than this value. Defaults to 5.

  Returns:
    The remaining words, in their original order.
  """
  # Lower-casing happens first, so the upper-case placeholders are not affected.
  # None of the placeholders contains a character that is itself replaced,
  # so the order of the substitutions does not matter.
  replacements = {
    '.':  ' <PERIOD> ',
    ',':  ' <COMMA> ',
    '"':  ' <QUOTATION_MARK> ',
    ';':  ' <SEMICOLON> ',
    '!':  ' <EXCLAMATION_MARK> ',
    '?':  ' <QUESTION_MARK> ',
    '(':  ' <LEFT_PAREN> ',
    ')':  ' <RIGHT_PAREN> ',
    '--': ' <HYPHENS> ',
    ':':  ' <COLON> ',
  }
  text = text.lower()
  for symbol, placeholder in replacements.items():
    text = text.replace(symbol, placeholder)

  words = text.split()
  stats = collections.Counter(words)
  return [word for word in words if stats[word] > rare_threshold]


#
#
#
# Clean the raw text into a flat list of words, then show its type, its
# length and its first seven entries. Recorded output of this cell:
#   <class 'list'>
#   16680599
#   ['anarchism', 'originated', 'as', 'a', 'term', 'of', 'abuse']
corpus: list[str] = preprocess(text8)
for shown in (type(corpus), len(corpus), corpus[:7]):
  print(shown)


#
#
#
def create_lookup_tables(words: list[str]) -> tuple[dict[str, int], dict[int, str]]:
  """Build word<->id mappings ordered by descending word frequency.

  The most frequent word gets id 1, the next id 2, and so on; words with
  equal counts keep the order in which they first appear in `words`.
  Id 0 is assigned to the '<PAD>' entry.

  Args:
    words: The corpus as a list of words.

  Returns:
    A `(vocab_to_int, int_to_vocab)` pair of dictionaries.
  """
  # most_common() sorts by count, highest first, and is stable for ties.
  ranked = collections.Counter(words).most_common()

  int_to_vocab = {}
  for idx, (word, _count) in enumerate(ranked, start=1):
    int_to_vocab[idx] = word
  int_to_vocab[0] = '<PAD>'

  # Invert the table to obtain the word -> id direction.
  vocab_to_int = {word: idx for idx, word in int_to_vocab.items()}
  return vocab_to_int, int_to_vocab


#
#
#
# Build both lookup tables from the corpus, then encode the corpus as ids.
# Recorded output of this cell:
#   <class 'list'>
#   16680599
#   [5234, 3081, 12, 6, 195, 2, 3134]
words_to_ids, ids_to_words = create_lookup_tables(corpus)
tokens = list(map(words_to_ids.__getitem__, corpus))
for shown in (type(tokens), len(tokens), tokens[:7]):
  print(shown)


#
#
#
# Spot-check the lookup tables in both directions; the trailing comments show
# the values recorded in this cell's output.
print(ids_to_words[5234])        # anarchism
print(words_to_ids['anarchism']) # 5234
print(words_to_ids['have'])      # 39
print(len(words_to_ids))         # 63,642 (vocabulary size, including '<PAD>')


#
#
#
class SkipGramOne(torch.nn.Module):
  """Skip-gram model with one softmax over the vocabulary for all targets.

  The centre word is embedded, projected to one logit per vocabulary entry,
  and the loss is the mean negative log-probability of the target ids.

  Args:
    voc: Number of vocabulary entries (embedding rows and output logits).
    emb: Embedding dimension.
    _: Unused; accepted so all models can share one argument tuple.
  """

  def __init__(self, voc, emb, _):
    super().__init__()
    self.emb = torch.nn.Embedding(num_embeddings=voc, embedding_dim=emb)
    self.ffw = torch.nn.Linear(in_features=emb, out_features=voc)
    # Not used by forward() any more; kept so the module's attributes are unchanged.
    self.max = torch.nn.Softmax(dim=1)

  def forward(self, inpt, trgs):
    """Return the mean negative log-likelihood of `trgs` given `inpt`.

    Args:
      inpt: 1-D tensor of centre-word ids. Only row 0 of the output is
        scored, so a single centre word is assumed.
      trgs: Tensor of target word ids to score against that centre word.

    Returns:
      Scalar loss tensor.
    """
    emb = self.emb(inpt)
    out = self.ffw(emb)
    # log_softmax avoids log(0) = -inf when a softmax probability underflows.
    log_probs = torch.log_softmax(out, dim=1)
    return -log_probs[0, trgs].mean()


#
#
#
class SkipGramTwo(torch.nn.Module):
  """Skip-gram model with a separate softmax per context position.

  The centre word is embedded and projected to `ctx * voc` logits, which are
  reshaped into `ctx` rows. Row i is scored against target i.

  Args:
    voc: Number of vocabulary entries.
    emb: Embedding dimension.
    ctx: Number of context positions (rows after the reshape).
  """

  def __init__(self, voc, emb, ctx):
    super().__init__()
    self.ctx = ctx
    self.emb = torch.nn.Embedding(num_embeddings=voc, embedding_dim=emb)
    self.ffw = torch.nn.Linear(in_features=emb, out_features=ctx*voc)
    # Not used by forward() any more; kept so the module's attributes are unchanged.
    self.max = torch.nn.Softmax(dim=1)

  def forward(self, inpt, trgs):
    """Return the mean negative log-likelihood of the per-position targets.

    Args:
      inpt: Tensor of centre-word ids. The reshape to `ctx` rows implies a
        single centre word.
      trgs: Tensor of `ctx` target ids, one per context position.

    Returns:
      Scalar loss tensor.
    """
    emb = self.emb(inpt)
    hid = self.ffw(emb)
    lgt = hid.view(self.ctx, -1)
    # log_softmax avoids log(0) = -inf when a softmax probability underflows.
    log_probs = torch.log_softmax(lgt, dim=1)
    # Build the row index on the same device as the logits.
    rows = torch.arange(log_probs.size(0), device=log_probs.device)
    return -log_probs[rows, trgs].mean()


#
#
#
class SkipGramTre(torch.nn.Module):
  """Skip-gram model scored with a sigmoid on centre/target dot products.

  Rows of the bias-free linear layer's weight matrix act as per-word output
  vectors. Only the rows for the given targets are used, so no logits are
  computed for the rest of the vocabulary.

  Args:
    voc: Number of vocabulary entries.
    emb: Embedding dimension.
    ctx: Stored on the module; not read by forward().
  """

  def __init__(self, voc, emb, ctx):
    super().__init__()
    self.ctx = ctx
    self.emb = torch.nn.Embedding(num_embeddings=voc, embedding_dim=emb)
    self.ffw = torch.nn.Linear(in_features=emb, out_features=voc, bias=False)
    # Not used by forward() any more; kept so the module's attributes are unchanged.
    self.sig = torch.nn.Sigmoid()

  def forward(self, inpt, trgs):
    """Return the mean of -log(sigmoid(target_vector . centre_embedding)).

    Args:
      inpt: 1-D tensor of centre-word ids.
      trgs: 1-D tensor of target word ids.

    Returns:
      Scalar loss tensor.
    """
    emb = self.emb(inpt)
    ctx = self.ffw.weight[trgs]
    lgt = torch.mm(ctx, emb.T)
    # logsigmoid avoids log(0) = -inf when the sigmoid underflows.
    return -torch.nn.functional.logsigmoid(lgt).mean()


#
#
#
class SkipGramFoo(torch.nn.Module):
  """Skip-gram model with a positive term and a negative-sample term.

  Rows of the bias-free linear layer's weight matrix act as per-word output
  vectors. The loss pushes sigmoid(dot product) towards 1 for the target ids
  and towards 0 for the ids passed as `rand`.

  Args:
    voc: Number of vocabulary entries.
    emb: Embedding dimension.
    ctx: Stored on the module; not read by forward().
  """

  def __init__(self, voc, emb, ctx):
    super().__init__()
    self.ctx = ctx
    self.emb = torch.nn.Embedding(num_embeddings=voc, embedding_dim=emb)
    self.ffw = torch.nn.Linear(in_features=emb, out_features=voc, bias=False)
    # Not used by forward() any more; kept so the module's attributes are unchanged.
    self.sig = torch.nn.Sigmoid()

  def forward(self, inpt, trgs, rand):
    """Return the positive loss plus the negative-sample loss.

    Args:
      inpt: 1-D tensor of centre-word ids.
      trgs: 1-D tensor of target word ids (positive examples).
      rand: 1-D tensor of word ids treated as negative examples.

    Returns:
      Scalar loss tensor: mean(-log sigmoid(pos)) + mean(-log(1 - sigmoid(neg))).
    """
    emb = self.emb(inpt)
    ctx = self.ffw.weight[trgs]
    rnd = self.ffw.weight[rand]
    out = torch.mm(ctx, emb.T)
    rnd = torch.mm(rnd, emb.T)
    logsigmoid = torch.nn.functional.logsigmoid
    # logsigmoid avoids log(0) = -inf when the sigmoid saturates at 0 or 1.
    pst = -logsigmoid(out).mean()
    # log(1 - sigmoid(x)) equals logsigmoid(-x).
    ngt = -logsigmoid(-rnd).mean()
    return pst + ngt


#
#
#
# Constructor arguments shared by every model:
# (vocabulary size, embedding dimension, context positions).
vocab_size = len(words_to_ids)
args = (vocab_size, 64, 2)

# Instantiate in this fixed order; the generator builds one model at a time,
# left to right, so weight initialisation draws from the RNG in the same order.
mOne, mTwo, mTre, mFoo = (
  model_cls(*args)
  for model_cls in (SkipGramOne, SkipGramTwo, SkipGramTre, SkipGramFoo)
)


#
#
#
# Print the total number of parameter elements of each model.
for label, model in (('mOne', mOne), ('mTwo', mTwo), ('mTre', mTre), ('mFoo', mFoo)):
  print(label, sum(p.numel() for p in model.parameters()))


#
#
#
# One Adam optimizer per model, all with the same learning rate.
opOne, opTwo, opTre, opFoo = (
  torch.optim.Adam(net.parameters(), lr=0.003)
  for net in (mOne, mTwo, mTre, mFoo)
)


# #
# #
# #
# wandb.init(project='skip-gram', name='mOne')
# for epoch in range(10):
#   wins = more_itertools.windowed(tokens[:10000], 3)
#   prgs = tqdm.tqdm(enumerate(wins), total=len(tokens[:10000]), desc=f"Epoch {epoch+1}", leave=False)
#   for i, tks in prgs:
#     opOne.zero_grad()
#     inpt = torch.LongTensor([tks[1]])
#     trgs = torch.LongTensor([tks[0], tks[2]])
#     loss = mOne(inpt, trgs)
#     loss.backward()
#     opOne.step()
#     wandb.log({'loss': loss.item()})
# wandb.finish()


# #
# #
# #
# wandb.init(project='skip-gram', name='mTwo')
# for epoch in range(10):
#   wins = more_itertools.windowed(tokens[:10000], 3)
#   prgs = tqdm.tqdm(wins, desc=f"Epoch {epoch+1}", leave=False)
#   for i, tks in prgs:
#     inpt = torch.LongTensor([tks[1]])
#     trgs = torch.LongTensor([tks[0], tks[2]])
#     opTwo.zero_grad()
#     loss = mTwo(inpt, trgs)
#     loss.backward()
#     opTwo.step()
#     wandb.log({'loss': loss.item()})
# wandb.finish()


# #
# #
# #
# wandb.init(project='skip-gram', name='mTre')
# for epoch in range(10):
#   wins = more_itertools.windowed(tokens[:10000], 3)
#   prgs = tqdm.tqdm(enumerate(wins), total=len(tokens[:10000]), desc=f"Epoch {epoch+1}", leave=False)
#   for i, tks in prgs:
#     inpt = torch.LongTensor([tks[1]])
#     trgs = torch.LongTensor([tks[0], tks[2]])
#     opTre.zero_grad()
#     loss = mTre(inpt, trgs)
#     loss.backward()
#     opTre.step()
#     wandb.log({'loss': loss.item()})
# wandb.finish()



# Initialize Weights and Biases; the losses logged below go to this run.
wandb.init(project="cbow_training", entity="omareweis123")

# Move the model to the GPU if available
mFoo = mFoo.to(device)

# Only the first 10,000 tokens are used. The 3-token windows are the same in
# every epoch, so build them once. Each window is (left, centre, right).
wins = list(more_itertools.windowed(tokens[:10000], 3))

# Training loop. The try/finally makes sure the W&B run is closed even if
# training stops with an exception.
try:
    for epoch in range(10):
        # total is the real number of windows, so the progress bar can reach 100%.
        prgs = tqdm.tqdm(wins, total=len(wins), desc=f"Epoch {epoch+1}", leave=False)
        for tks in prgs:
            # windowed() pads with None when fewer than 3 tokens are available;
            # such a window cannot be turned into a tensor, so skip it.
            if None in tks:
                continue

            # Centre word is the input, its two neighbours are the targets.
            # Move input tensors to the same device (GPU or CPU) as the model.
            inpt = torch.LongTensor([tks[1]]).to(device)
            trgs = torch.LongTensor([tks[0], tks[2]]).to(device)
            # Two word ids drawn uniformly from the vocabulary; the model
            # receives them as its third argument.
            rand = torch.randint(0, len(words_to_ids), (2,)).to(device)

            # Zero gradients
            opFoo.zero_grad()

            # Forward pass
            loss = mFoo(inpt, trgs, rand)

            # Backward pass and optimization
            loss.backward()
            opFoo.step()

            # Log the loss
            wandb.log({'loss': loss.item()})
finally:
    # Finish the W&B logging
    wandb.finish()


#
#
#
# wandb.init(project="cbow_training", entity="omareweis123")

# # Define your token limit here (e.g., 10,000 tokens)
# token_limit = 10000

# # Split data into 80% train, 20% validation from the limited set
# train_tokens = tokens[:int(0.8 * token_limit)]
# val_tokens = tokens[int(0.8 * token_limit):token_limit]

# # Move the model to the correct device
# mFoo = mFoo.to(device)

# for epoch in range(10):
#     # Training loop
#     mFoo.train()
#     train_wins = list(more_itertools.windowed(train_tokens, 3))  # Create training windows
#     train_prgs = tqdm.tqdm(enumerate(train_wins), total=len(train_wins), desc=f"Epoch {epoch+1} [Train]", leave=False)
#     train_loss_total = 0
#     for i, tks in train_prgs:
#         if None in tks:  # Skip invalid windows (e.g., at the end)
#             continue
#         inpt = torch.LongTensor([tks[1]]).to(device)
#         trgs = torch.LongTensor([tks[0], tks[2]]).to(device)
#         rand = torch.randint(0, len(words_to_ids), (2,)).to(device)  # Move rand to device
#         opFoo.zero_grad()
#         loss = mFoo(inpt, trgs, rand)
#         loss.backward()
#         opFoo.step()
#         train_loss_total += loss.item()
#         wandb.log({'train_loss': loss.item()})

#     avg_train_loss = train_loss_total / len(train_wins)
#     wandb.log({'avg_train_loss': avg_train_loss})

#     # Validation loop
#     mFoo.eval()
#     val_wins = list(more_itertools.windowed(val_tokens, 3))  # Create validation windows
#     val_prgs = tqdm.tqdm(enumerate(val_wins), total=len(val_wins), desc=f"Epoch {epoch+1} [Validation]", leave=False)
#     val_loss_total = 0
#     with torch.no_grad():
#         for i, tks in val_prgs:
#             if None in tks:  # Skip invalid windows
#                 continue
#             inpt = torch.LongTensor([tks[1]]).to(device)
#             trgs = torch.LongTensor([tks[0], tks[2]]).to(device)
#             rand = torch.randint(0, len(words_to_ids), (2,)).to(device)  # Move rand to device
#             val_loss = mFoo(inpt, trgs, rand)
#             val_loss_total += val_loss.item()

#     avg_val_loss = val_loss_total / len(val_wins)
#     wandb.log({'avg_val_loss': avg_val_loss})

# wandb.finish()


<class 'list'>
16680599
['anarchism', 'originated', 'as', 'a', 'term', 'of', 'abuse']
<class 'list'>
16680599
[5234, 3081, 12, 6, 195, 2, 3134]
anarchism
5234
39
63642
mOne 8209818
mTwo 12346548
mTre 8146176
mFoo 8146176


0,1
loss,▇▂▇▄▅▂▄▇▇▃▄▆▅▇▄█▇▂▅▄▇▄▄▅█▆▅▅▆▄▅▇█▆▄▆▇▄▇▁

0,1
loss,0.60369


                                                               

0,1
loss,▂▃▂▃▂▄▂▄▂▂▂▃▂▁▁▂▂▂█

0,1
loss,
