In [48]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel,TFGPT2LMHeadModel,TextDataset,DataCollatorForLanguageModeling, Trainer, TrainingArguments
import chess
import chess.pgn
import gpt_2_simple as gpt2
import os



In [58]:
# Load the base GPT-2 checkpoint. The model's pad id is set to the tokenizer's
# EOS id, so padding and end-of-text share one token id.
tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path='gpt2')
model = GPT2LMHeadModel.from_pretrained(
    'gpt2',
    pad_token_id=tokenizer.eos_token_id,
)


In [59]:
# Training text: the chess move-text file 'chess.txt' (change the path to train
# on another file). TextDataset tokenizes the whole file and cuts it into
# consecutive chunks of block_size=32 tokens; a trailing partial chunk is dropped,
# and chunks are not aligned to game boundaries.
# NOTE(review): TextDataset is deprecated in recent transformers releases;
# consider the `datasets` library for anything beyond this experiment.
dataset = TextDataset(tokenizer=tokenizer, file_path='chess.txt', block_size=32)


In [60]:
# Causal-LM batching: with mlm=False no tokens are masked; the labels are the
# input ids themselves (padding positions ignored).
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)

In [61]:
training_args = TrainingArguments(
    # Output: checkpoints and the final model go here; reruns overwrite it.
    output_dir='./results',
    overwrite_output_dir=True,
    # Optimisation: effective batch per device = 16 x 4 accumulation steps.
    num_train_epochs=1,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    prediction_loss_only=True,
    # Checkpointing: one every 1000 steps, keeping at most two on disk.
    save_steps=1000,
    save_total_limit=2,
    # Logging: loss reported every 1000 steps.
    logging_dir='./logs',
    logging_steps=1000,
    # NOTE(review): the recorded run finished after a single optimisation step,
    # so neither the 1000-step logging nor checkpointing ever triggers there.
)

In [62]:
# Wire model, tokenized data, collator and hyper-parameters into a Trainer,
# then fine-tune; the returned TrainOutput is shown as the cell result.
trainer = Trainer(
    model=model,
    train_dataset=dataset,
    data_collator=data_collator,
    args=training_args,
)

trainer.train()



Step,Training Loss


TrainOutput(global_step=1, training_loss=0.8834662437438965, metrics={'train_runtime': 6.7444, 'train_samples_per_second': 0.741, 'train_steps_per_second': 0.148, 'total_flos': 81653760000.0, 'train_loss': 0.8834662437438965, 'epoch': 1.0})

In [63]:
 # Save the fine-tuned weights and config to the Trainer's output directory.
 # The Trainer above was created without a tokenizer, so save_model() does not
 # write tokenizer files; save them too so the directory reloads on its own.
 trainer.save_model(training_args.output_dir)
 tokenizer.save_pretrained(training_args.output_dir)

In [66]:
 # Reload the fine-tuned model from the directory the Trainer saved to, rather
 # than the hard-coded '/content/results' (which only matched './results' when
 # the working directory happened to be /content, as on Colab).
 model = GPT2LMHeadModel.from_pretrained(training_args.output_dir)

In [70]:
def generate_text(sequence, max_new_tokens=3, top_k=50, top_p=0.95):
    """Sample a short continuation of ``sequence`` from the global ``model``.

    Args:
        sequence: Prompt, e.g. a move list such as "1. e4 e5 2. Nf3".
            Non-string values are converted with ``str()``.
        max_new_tokens: Number of tokens sampled after the prompt. The default
            of 3 is meant to cover one SAN move. NOTE(review): longer moves such
            as "Nbxd7+" may need more tokens; TODO confirm.
        top_k: Top-k sampling cutoff.
        top_p: Nucleus (top-p) sampling cutoff.

    Returns:
        The decoded prompt followed by the sampled continuation, with special
        tokens removed.
    """
    # Calling the tokenizer (instead of .encode) also yields the attention_mask,
    # which generate() needs because pad and EOS share the same id here.
    # Move the tensors to wherever the model lives (CPU or GPU).
    inputs = tokenizer(str(sequence), return_tensors='pt').to(model.device)
    final_outputs = model.generate(
        **inputs,
        do_sample=True,
        max_new_tokens=max_new_tokens,
        pad_token_id=model.config.eos_token_id,
        top_k=top_k,
        top_p=top_p,
    )
    return tokenizer.decode(final_outputs[0], skip_special_tokens=True)

In [71]:
# Smoke test: sample the model's reply to 1. e4 e5 2. Nf3.
generate_text(sequence="1. e4 e5 2. Nf3")

'1. e4 e5 2. Nf3 Nc6'

In [22]:
class Game:
    """Interactive game: a human plays White against the language model as Black.

    The game starts from the position after 1. e4 e5 with the human to move.
    ``self.pgn`` holds the numbered move text so far and is the prompt passed
    to ``generate_text`` for the model's replies.
    """

    def __init__(self, max_ai_attempts=5):
        """Set up the board after 1. e4 e5 with the human to move.

        Args:
            max_ai_attempts: How many samples to draw from the model before
                giving up on it and playing the first legal move instead.
        """
        self.board = chess.Board()
        self.board.push_san("e4")
        self.board.push_san("e5")
        self.player = "human"
        self.pgn = "1. e4 e5"
        # Number of the next full move; it is written when White (the human) moves.
        self.move_count = 2
        self.max_ai_attempts = max_ai_attempts

    def get_move(self):
        """Play one half-move for the side to move and append it to ``pgn``.

        The board, move text, move counter and current player are only updated
        once a legal move has been obtained. A bad input or a bad model sample
        therefore cannot leave the game half-updated.

        Raises:
            RuntimeError: if called after the game has ended.
        """
        if self.board.is_game_over():
            raise RuntimeError(f"Game is over: {self.board.result()}")
        if self.player == "human":
            move = self._read_human_move()
            # The human plays White, so a new move number starts here.
            self.pgn += " " + str(self.move_count) + "."
            self.move_count += 1
            next_player = "ai"
        else:
            move = self._sample_ai_move()
            next_player = "human"
        # Record canonical SAN (computed before pushing) so the prompt keeps one
        # notation no matter how the human typed the move.
        self.pgn += " " + self.board.san(move)
        self.board.push(move)
        self.player = next_player

    def _read_human_move(self):
        """Prompt until the user enters a legal move; return it as a chess.Move."""
        while True:
            text = input("Your move: ").strip()
            try:
                return self.board.parse_san(text)
            except ValueError:
                # python-chess signals invalid, illegal and ambiguous SAN with
                # ValueError subclasses.
                print(f"Illegal or unparsable move {text!r}; try again.")

    def _sample_ai_move(self):
        """Ask the model for a legal move, retrying; fall back to the first legal move."""
        # NOTE(review): the whole game so far is the prompt. GPT-2's context is
        # limited (1024 positions for the base model), so very long games may need
        # the prompt truncated; TODO confirm.
        for _ in range(self.max_ai_attempts):
            completion = generate_text(self.pgn)
            # generate_text returns prompt + continuation; keep only the continuation.
            if completion.startswith(self.pgn):
                continuation = completion[len(self.pgn):]
            else:
                continuation = completion.replace(self.pgn, "")
            candidates = continuation.split()
            if not candidates:
                continue
            try:
                move = self.board.parse_san(candidates[0])
            except ValueError:
                # Illegal, partial or non-move text (e.g. a move number): resample.
                continue
            print("AI plays:", self.board.san(move))
            return move
        legal_moves = list(self.board.legal_moves)
        if not legal_moves:
            raise RuntimeError("No legal moves: the game is over.")
        fallback = legal_moves[0]
        print(f"Model produced no legal move in {self.max_ai_attempts} tries; "
              f"playing {self.board.san(fallback)} instead.")
        return fallback

game = Game()

In [None]:
# Alternate human and model moves until the game ends (checkmate, stalemate,
# or another terminal condition), then show the final position and the result.
while not game.board.is_game_over():
    display(game.board)
    game.get_move()

display(game.board)
game.board.result()