Setup

In [2]:
import html
import os
import sys

In [3]:
# True when this kernel is a Google Colab runtime (Colab pre-imports `google.colab`).
# Detected instead of hand-toggled: a stale manual flag either breaks the
# `google.colab` import below or selects the wrong BASE_PATH. Assign a literal to override.
run_in_colab = 'google.colab' in sys.modules

In [None]:
if run_in_colab:
  !pip install transformers
  !pip install wandb

  from google.colab import drive
  drive.mount('/content/drive')
  
  !git clone https://github.com/nofarmordehai/Learn-Chess-Commentary.git 'chess'
  CODE_DIR = 'chess'
  os.chdir(f'./{CODE_DIR}')

In [5]:
from transformers import AdamW, get_linear_schedule_with_warmup
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import torch
import time
from Models.GPT2 import GPT2
from Dataset.MovesDataset import MovesDataset
from Configs.train_config import config

In [None]:
import wandb

# Authenticate with Weights & Biases before the wandb.init call further down.
wandb.login()

In [7]:
# Root folder holding Data/ and Models/: the mounted Google Drive on Colab, a fixed server path otherwise.
BASE_PATH = '/content/drive/MyDrive/NLP/' if run_in_colab else '/home/joberant/nlp_fall_2021/nofarm/chess/'

In [8]:
# `games_data_path` is a filename prefix: the dataset cell appends "<n>.p" to it.
games_data_path = f'{BASE_PATH}Data/NEW_fix/games_data'
saved_models_path = f'{BASE_PATH}Models/'

GPT2


In [9]:
# Project wrapper (Models.GPT2); the cells below use its `.model` and `.tokenizer` attributes.
gpt2 = GPT2()

In [10]:
# Switch the model to training mode in place (train() returns the same module); `;` hides its repr.
gpt2.model.train();

Dataset

In [11]:
# Game shards are named games_data1.p ... games_dataN.p, N = config['NUMER_OF_DATA_DIRS'].
# NOTE(review): `.p` suggests pickle files — if MovesDataset unpickles them, only load trusted data.
dataset = MovesDataset(
    [f'{games_data_path}{shard}.p' for shard in range(1, config['NUMER_OF_DATA_DIRS'] + 1)],
    gpt2.tokenizer,
)

In [12]:
# Split sizes: `train_precentege` of the samples for training, the remainder for testing.
train_size = int(config['train_precentege'] * len(dataset))
test_size = len(dataset) - train_size
# Both splits must be non-empty: the validation batch is drawn with next(iter(test_dataloader)),
# which would otherwise fail with an uninformative StopIteration.
assert train_size > 0 and test_size > 0, f'degenerate split: train={train_size}, test={test_size}'

In [13]:
# Seed the split so every session (and every later evaluation of a saved checkpoint) sees the
# same train/test partition; an unseeded split leaks training samples into the test set across runs.
SPLIT_SEED = 42
train_dataset, test_dataset = random_split(
    dataset, [train_size, test_size], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

In [14]:
# Shuffled mini-batch loaders; the test loader also supplies the validation batch drawn below.
train_dataloader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=config['batch_size'])
test_dataloader = DataLoader(dataset=test_dataset, shuffle=True, batch_size=config['batch_size'])

Validation text

In [None]:
# Start the wandb run, recording the main hyper-parameters under display-friendly keys.
run = wandb.init(
    config={
        'batch size': config['batch_size'],
        'lr': config['lr'],
        'epochs': config['epochs'],
    },
    project="LmChess",
)

In [16]:
# Draw one batch from the test split; it is reused as the fixed set of validation prompts below.
(
    validation_proccessed_data,
    validation_attn_masks,
    validation_labels,
) = next(iter(test_dataloader))

In [21]:
# Build one generation prompt per validation sample and log the reference texts to wandb.
# validation_input_encodings[i] holds sample i's token ids up to and including the
# <comment> token, shaped (1, prompt_len) on the GPU, ready for `generate`.
validation_input_encodings = []
validation_texts = {}
# Iterate over the actual batch: it is smaller than batch_size when the test split is small.
for i in range(len(validation_proccessed_data)):
  sample_ids = validation_proccessed_data[i]
  textual_validation_data = gpt2.tokenizer.decode(token_ids=sample_ids, skip_special_tokens=False).split('<comment> ')

  validation_input_text = textual_validation_data[0]
  validation_target_text = textual_validation_data[1].split(' <|endoftext|>')[0]

  # Escape the text: special tokens such as <|endoftext|> would otherwise be parsed as HTML tags.
  validation_texts[f"validation_target_text {i}"] = wandb.Html(f'<p>{html.escape(validation_target_text)}</p>')
  validation_texts[f"validation_input_text {i}"] = wandb.Html(f'<p>{html.escape(validation_input_text)}</p>')

  comment_idx = list(sample_ids).index(dataset.comment_encoding) + 1
  # Slice sample i itself (previously sample 0 was sliced, so every prompt came from the first sample).
  validation_input_encoding = sample_ids[:comment_idx].unsqueeze(0).cuda()
  validation_input_encodings.append(validation_input_encoding)

# A single wandb.log call, so the reference texts occupy one step instead of two per sample.
wandb.log(validation_texts)

Train

In [22]:
# Optimizer and learning-rate schedule for fine-tuning.
optimizer = AdamW(gpt2.model.parameters(), lr=config['lr'])

# get_linear_schedule_with_warmup needs the real number of optimizer steps. Passing -1 makes
# the post-warm-up factor max(0, (-1 - step) / max(1, -5001)) = 0, i.e. the learning rate
# collapses to zero after warm-up and the model silently stops learning.
NUM_WARMUP_STEPS = 5000
num_training_steps = config['epochs'] * len(train_dataloader)  # one scheduler.step() per batch
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=NUM_WARMUP_STEPS, num_training_steps=num_training_steps
)

# Most recent batch loss; the training cell embeds int(loss) in checkpoint filenames.
loss = 0

In [None]:
# Fine-tune the model, periodically logging sample generations and saving checkpoints.
epochs = config['epochs']
VALIDATION_EVERY = 2000      # batches between validation generations (skipped at idx 0)
CHECKPOINT_EVERY = 50000     # batches between checkpoints (also fires at idx 0 of every epoch)
GENERATION_MAX_LENGTH = 769  # NOTE(review): magic length carried over unchanged — confirm its origin


def log_validation_generations():
    """Generate a continuation for every validation prompt and log them all in one wandb step.

    The model is put in eval mode so dropout does not perturb beam search, and is always
    switched back to train mode afterwards, even if generation raises.
    """
    gpt2.model.eval()
    try:
        generations = {}
        with torch.no_grad():
            for i, prompt in enumerate(validation_input_encodings):
                generated = gpt2.model.generate(
                    prompt, num_beams=2, no_repeat_ngram_size=2, max_length=GENERATION_MAX_LENGTH
                )
                output_text = gpt2.tokenizer.decode(generated[0], skip_special_tokens=True)
                generations[f"output_text {i}"] = wandb.Html(f'<p>{html.escape(output_text)}</p>')
        wandb.log(generations)
    finally:
        gpt2.model.train()


def save_checkpoint(tag):
    """Save the model weights to `{saved_models_path}{tag}_{timestamp}_{int(loss)}.bin`.

    `loss` is the module-level loss of the most recent batch (0 before the first step).
    """
    torch.save(gpt2.model.state_dict(), f'{saved_models_path}{tag}_{time.time()}_{int(loss)}.bin')


for epoch in range(epochs):
    # tqdm over the loader itself, so the bar counts batches for any batch size.
    for idx, (inputs, attn_masks, labels) in enumerate(tqdm(train_dataloader, desc=f'epoch {epoch}')):
        if idx % VALIDATION_EVERY == 0 and idx != 0:
            log_validation_generations()

        if idx % CHECKPOINT_EVERY == 0:
            save_checkpoint(idx)

        gpt2.model.zero_grad()
        outputs = gpt2.model(inputs.cuda(), labels=labels.cuda(), attention_mask=attn_masks.cuda())

        batch_loss = outputs['loss']
        batch_loss.backward()
        optimizer.step()
        scheduler.step()

        # Keep and log a plain Python float rather than the tensor.
        loss = batch_loss.item()
        wandb.log({"epoch": epoch, "loss": loss})

# Periodic checkpoints only fire every CHECKPOINT_EVERY batches, so persist the final weights explicitly.
save_checkpoint('final')