In [None]:
! pip install transformers torch sentencepiece accelerate

In [2]:
import pandas as pd
from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
from torch.utils.data import Dataset
import torch
from sklearn.model_selection import train_test_split

# Load the dataset
file_path = 'quantized_coordinates.csv'
coordinates_df = pd.read_csv(file_path)

# Combine y_quant and x_quant into a sequence
coordinates_df['sequence'] = coordinates_df['y_quant'].astype(str) + ' ' + coordinates_df['x_quant'].astype(str)

# Prepare data: We'll create input-output pairs where the input is the first few coordinates, and the output is the full sequence.
def prepare_data(df, input_len=3):
    input_sequences = []
    output_sequences = []
    for i in range(len(df) - input_len):
        input_seq = ' '.join(df['sequence'].iloc[i:i+input_len])
        output_seq = ' '.join(df['sequence'].iloc[i:i+input_len+1])
        input_sequences.append(input_seq)
        output_sequences.append(output_seq)
    return input_sequences, output_sequences

input_seqs, output_seqs = prepare_data(coordinates_df)

# Create a custom dataset class
class CoordinateDataset(Dataset):
    def __init__(self, inputs, outputs, tokenizer, max_len):
        self.inputs = inputs
        self.outputs = outputs
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        input_seq = self.inputs[idx]
        output_seq = self.outputs[idx]
        
        # Tokenize the input and output sequences
        inputs = self.tokenizer(input_seq, max_length=self.max_len, padding='max_length', truncation=True, return_tensors="pt")
        outputs = self.tokenizer(output_seq, max_length=self.max_len, padding='max_length', truncation=True, return_tensors="pt")

        return {
            'input_ids': inputs.input_ids.flatten(),
            'attention_mask': inputs.attention_mask.flatten(),
            'labels': outputs.input_ids.flatten()
        }

# Initialize the tokenizer and model
model_name = "t5-small"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

# Dataset parameters
MAX_LEN = 50
BATCH_SIZE = 8


# Split the data into train and validation sets (80% train, 20% validation)
train_inputs, val_inputs, train_outputs, val_outputs = train_test_split(
    input_seqs, output_seqs, test_size=0.2, random_state=42
)

# Create train and eval datasets
train_dataset = CoordinateDataset(train_inputs, train_outputs, tokenizer, max_len=MAX_LEN)
eval_dataset = CoordinateDataset(val_inputs, val_outputs, tokenizer, max_len=MAX_LEN)

# Training arguments (updated)
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=10,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    logging_dir='./logs',
    logging_steps=10,
    evaluation_strategy="steps",  # Evaluate at each logging step
    eval_steps=50,  # Evaluate every 50 steps
    save_steps=500,  # Save model every 500 steps
)

# Initialize Trainer with evaluation dataset
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,  # Add the evaluation dataset
)

# Train the model
trainer.train()


# Inference: Generate a sequence given the first few points
def generate_sequence(model, tokenizer, input_sequence, max_length=50):
    inputs = tokenizer(input_sequence, return_tensors="pt")
    output = model.generate(inputs.input_ids, max_length=max_length)
    return tokenizer.decode(output[0], skip_special_tokens=True)

# Example prediction
input_sequence = '88 121 89 121 90'
predicted_sequence = generate_sequence(model, tokenizer, input_sequence)
print(f"Input: {input_sequence}")
print(f"Predicted Sequence: {predicted_sequence}")


                                                 
  1%|          | 10/1830 [02:27<12:42,  2.39it/s]

{'loss': 9.3745, 'grad_norm': 168.86695861816406, 'learning_rate': 4.965986394557823e-05, 'epoch': 0.07}


                                                 
  1%|          | 10/1830 [02:32<12:42,  2.39it/s]

{'loss': 4.251, 'grad_norm': 41.83161544799805, 'learning_rate': 4.931972789115647e-05, 'epoch': 0.14}


                                                 
  1%|          | 10/1830 [02:36<12:42,  2.39it/s]

{'loss': 1.788, 'grad_norm': 5.38971471786499, 'learning_rate': 4.89795918367347e-05, 'epoch': 0.2}


                                                 
  1%|          | 10/1830 [02:40<12:42,  2.39it/s]

{'loss': 1.2663, 'grad_norm': 4.641617298126221, 'learning_rate': 4.8639455782312926e-05, 'epoch': 0.27}


                                                 
  1%|          | 10/1830 [02:44<12:42,  2.39it/s]

{'loss': 1.0262, 'grad_norm': 3.3070597648620605, 'learning_rate': 4.8299319727891155e-05, 'epoch': 0.34}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                 

[A[A                                         
  1%|          | 10/1830 [02:47<12:42,  2.39it/s]
[A
[A

{'eval_loss': 0.5146335363388062, 'eval_runtime': 3.1573, 'eval_samples_per_second': 92.8, 'eval_steps_per_second': 11.719, 'epoch': 0.34}


                                                 
  1%|          | 10/1830 [02:51<12:42,  2.39it/s]

{'loss': 0.8014, 'grad_norm': 2.6286284923553467, 'learning_rate': 4.795918367346939e-05, 'epoch': 0.41}


                                                 
  1%|          | 10/1830 [02:56<12:42,  2.39it/s]

{'loss': 0.6588, 'grad_norm': 10.684414863586426, 'learning_rate': 4.761904761904762e-05, 'epoch': 0.48}


                                                 
  1%|          | 10/1830 [03:02<12:42,  2.39it/s]

{'loss': 0.5002, 'grad_norm': 1.4552953243255615, 'learning_rate': 4.7278911564625856e-05, 'epoch': 0.54}


                                                 
  1%|          | 10/1830 [03:08<12:42,  2.39it/s]

{'loss': 0.4799, 'grad_norm': 2.1467034816741943, 'learning_rate': 4.6938775510204086e-05, 'epoch': 0.61}


