# T5 (t5-small) fine-tuning to answer crossword clues

## Load data

In [3]:
import pandas as pd

# Load data
def _read_lines(path, encoding):
    """Return the lines of ``path`` split on '\\n' only, without terminators.

    str.splitlines() would also break on \\x0b, \\x0c, \\x1c-\\x1e, \\x85, \\u2028 and
    \\u2029; one of those inside a single record would shift every following
    source/target pair out of alignment. newline='\\n' makes the file object split
    on '\\n' only and return '\\r' untranslated, so a CRLF '\\r' is stripped by hand.
    """
    lines = []
    with open(path, 'r', encoding=encoding, newline='\n') as f:
        for line in f:
            if line.endswith('\n'):
                line = line[:-1]
            if line.endswith('\r'):
                line = line[:-1]
            lines.append(line)
    return lines


def load_data(source_file, target_file, encoding='utf-8'):
    """Read a parallel corpus stored as two line-aligned text files.

    Line i of ``source_file`` is paired with line i of ``target_file``.

    Args:
        source_file: path to the file with one source text per line.
        target_file: path to the file with one target text per line.
        encoding: text encoding of both files (explicit so results do not depend
            on the platform's locale default).

    Returns:
        (sources, targets): two lists of str, one entry per line, without line
        terminators.

    Raises:
        ValueError: if the two files contain a different number of lines.
    """
    sources = _read_lines(source_file, encoding)
    targets = _read_lines(target_file, encoding)
    if len(sources) != len(targets):
        raise ValueError(
            f"{source_file} has {len(sources)} lines but {target_file} has "
            f"{len(targets)}; source/target pairs would be misaligned"
        )
    return sources, targets

train_sources, train_targets = load_data('dataset/train.source', 'dataset/train.target')
val_sources, val_targets = load_data('dataset/val.source', 'dataset/val.target')
test_sources, test_targets = load_data('dataset/test.source', 'dataset/test.target')

# Sources and targets are consumed index-by-index (line i of .source pairs with
# line i of .target), so every split must have exactly as many targets as sources.
# Fail fast here instead of training or evaluating on shifted pairs.
assert len(train_sources) == len(train_targets), "train: source/target line counts differ"
assert len(val_sources) == len(val_targets), "val: source/target line counts differ"
assert len(test_sources) == len(test_targets), "test: source/target line counts differ"

print("train_sources:", len(train_sources))
print("train_targets:", len(train_targets))
print("val_sources:", len(val_sources))
print("val_targets:", len(val_targets))
print("test_sources:", len(test_sources))
print("test_targets:", len(test_targets))


train_sources: 433034
train_targets: 433034
val_sources: 72304
val_targets: 72304
test_sources: 72940
test_targets: 72940


## Model initialization

In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
import torch
# Tokenizer and model are both built from the same pretrained 't5-small' checkpoint.
base_checkpoint = 't5-small'
tokenizer = T5Tokenizer.from_pretrained(base_checkpoint)
model = T5ForConditionalGeneration.from_pretrained(base_checkpoint)


In [None]:
from torch.utils.data import Dataset

class T5Dataset(Dataset):
    """Map-style dataset that tokenizes (source, target) text pairs on the fly for T5.

    Every item is a dict of 1-D tensors padded to fixed lengths, so the default
    collator can stack them into a batch:

    - ``input_ids`` / ``attention_mask``: encoder input, length ``max_length``.
    - ``labels``: decoder targets, length ``max_target_length``.

    Args:
        tokenizer: callable tokenizer (e.g. ``T5Tokenizer``) that accepts
            ``max_length``, ``padding``, ``truncation`` and ``return_tensors`` and
            exposes ``pad_token_id``.
        texts: source strings.
        targets: target strings; ``targets[i]`` is the label for ``texts[i]``.
        max_length: pad/truncate length for the encoder input.
        max_target_length: pad/truncate length for the labels.
        label_ignore_index: value written over padding positions in ``labels`` so
            the loss ignores them (Hugging Face seq2seq models skip ``-100``).
            Pass ``None`` to keep the tokenizer's pad ids in ``labels``.

    Raises:
        ValueError: if ``texts`` and ``targets`` differ in length.
    """

    def __init__(self, tokenizer, texts, targets, max_length=512, max_target_length=128,
                 label_ignore_index=-100):
        if len(texts) != len(targets):
            raise ValueError(
                f"texts and targets must be aligned: got {len(texts)} texts "
                f"and {len(targets)} targets"
            )
        self.tokenizer = tokenizer
        self.texts = texts
        self.targets = targets
        self.max_length = max_length
        self.max_target_length = max_target_length
        self.label_ignore_index = label_ignore_index

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        source_encoded = self.tokenizer(
            self.texts[idx],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        target_encoded = self.tokenizer(
            self.targets[idx],
            max_length=self.max_target_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # return_tensors='pt' adds a leading batch axis of size 1. squeeze(0) removes
        # only that axis; a bare squeeze() would also collapse the sequence axis when
        # its length is 1.
        labels = target_encoded['input_ids'].squeeze(0)
        if self.label_ignore_index is not None:
            # padding='max_length' fills short targets with pad ids; left in place they
            # would be trained on as real target tokens. Replace them with the ignore index.
            labels = labels.masked_fill(
                labels == self.tokenizer.pad_token_id, self.label_ignore_index
            )

        return {
            'input_ids': source_encoded['input_ids'].squeeze(0),
            'attention_mask': source_encoded['attention_mask'].squeeze(0),
            'labels': labels,
        }

# Wrap the raw train/val line lists in tokenizing datasets (the test split is not wrapped here).
train_dataset = T5Dataset(tokenizer=tokenizer, texts=train_sources, targets=train_targets)
val_dataset = T5Dataset(tokenizer=tokenizer, texts=val_sources, targets=val_targets)


In [None]:

# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',          # output directory (checkpoints are written here)
    num_train_epochs=3,              # number of training epochs
    per_device_train_batch_size=8,   # batch size for training
    per_device_eval_batch_size=16,   # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=10,                # log metrics every 10 steps
    seed=42,                         # the library default, stated explicitly for reproducibility
)

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

# Start training
train_result = trainer.train()

# No evaluation strategy is set above, and the Trainer does not evaluate during
# training by default, so val_dataset would otherwise never be used. Score the
# final weights on the validation split explicitly.
val_metrics = trainer.evaluate()
val_metrics


## Prediction

In [None]:
checkpoint_path = "results/checkpoint-162000"

# Restore weights from this Trainer checkpoint directory; the step number is hard-coded,
# so adjust it if training is rerun with different settings.
model = T5ForConditionalGeneration.from_pretrained(pretrained_model_name_or_path=checkpoint_path)



In [None]:
# Tokenizer comes from the base 't5-small' vocabulary, not from the checkpoint directory.
tokenizer = T5Tokenizer.from_pretrained(pretrained_model_name_or_path='t5-small')

def predict_words(clue, model, k, num_predictions=500, tok=None):
    """Generate candidate answers of exactly ``k`` characters for a clue via beam search.

    Args:
        clue: prompt text fed to the encoder (e.g. "crossword: city in portugal is:").
        model: seq2seq model exposing ``eval()``, ``generate()`` and ``device``
            (e.g. ``T5ForConditionalGeneration``).
        k: required answer length in characters; must be >= 1.
        num_predictions: beam width, and also the number of sequences requested
            from ``generate`` before length filtering.
        tok: tokenizer used to encode the clue and decode the outputs; defaults to
            the notebook-level ``tokenizer``.

    Returns:
        list[str]: decoded sequences whose length is exactly ``k``, in the order
        returned by ``generate``. May contain duplicates and may be shorter than
        ``num_predictions``.

    Raises:
        ValueError: if ``k`` < 1.
    """
    if k < 1:
        raise ValueError(f"k must be a positive word length, got {k}")
    if tok is None:
        tok = tokenizer  # notebook-global tokenizer

    # Put the prompt on the model's device so this also works for a GPU-resident model.
    input_ids = tok.encode(clue, return_tensors="pt").to(model.device)

    model.eval()  # disable dropout for decoding
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            num_return_sequences=num_predictions,
            # max_length counts tokens, including the decoder start token, so this allows
            # at most k generated tokens - enough whenever the answer tokenizes into <= k pieces.
            max_length=k + 1,
            num_beams=num_predictions,
            # False: stop using the library's "unlikely to find better candidates" heuristic
            # (True would stop as soon as num_beams hypotheses have finished).
            early_stopping=False,
        )

    # Decode each beam and keep only answers of the requested character length.
    predictions = [tok.decode(seq, skip_special_tokens=True) for seq in outputs]
    return [word for word in predictions if len(word) == k]

# Example: print the 6-letter candidates generated for a single clue.
clue = "crossword: city in portugal is:"
word_length = 6
predictions = predict_words(clue, model, k=word_length)
print(predictions)