In [None]:
# 🧩 Install required packages
!pip install -U transformers
!pip install -q datasets accelerate torchvision

# 🧠 Imports
import os
import json
import shutil
from PIL import Image
import torch
from datasets import Dataset
from transformers import (
    TrOCRProcessor,
    VisionEncoderDecoderModel,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

# ✅ Setup paths
# Locations under the Drive mount point.
drive_base = "/content/drive/MyDrive"
drive_dataset = os.path.join(drive_base, "OCR_dataset")
colab_base = os.path.join(drive_base, "Colab Notebooks")

# Working locations under /content used during training.
local_base = "/content/ocr_training_temp_printed"
local_dataset = os.path.join(local_base, "OCR_dataset")

# ✅ Copy dataset locally if not exists
# The copy is skipped whenever the local folder is already present, so a
# re-run of this cell does not copy the dataset a second time.
if not os.path.exists(local_dataset):
    shutil.copytree(src=drive_dataset, dst=local_dataset)

# 📁 Load JSON array (train/valid)
def load_json_array(path):
    """Load a JSON file whose top-level value must be an array.

    Args:
        path: Path of the JSON file to read.

    Returns:
        The decoded list. An empty list is returned (and the problem is
        printed) when the content is not valid JSON, cannot be decoded, or
        is valid JSON but not an array, so callers should treat ``[]`` as
        "nothing usable was loaded".

    Raises:
        OSError: If the file cannot be opened (for example, it is missing).
    """
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(path, 'r', encoding='utf-8') as f:
        try:
            data = json.load(f)
        except (OSError, ValueError, RecursionError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
            # RecursionError covers pathologically nested documents.
            print(f"❌ Errore nel file {path}: {e}")
            return []

    if not isinstance(data, list):
        # Same message the user has always seen for a non-array payload.
        reason = f"❌ Il file {path} non contiene un array JSON."
        print(f"❌ Errore nel file {path}: {reason}")
        return []

    print(f"✅ Caricati {len(data)} esempi da {path}")
    return data

# Read both splits before validating either of them.
train_data = load_json_array(
    os.path.join(local_dataset, "train", "train_data.json")
)
valid_data = load_json_array(
    os.path.join(local_dataset, "valid", "valid_data.json")
)

# load_json_array reports failures by returning an empty list, so an empty
# split is treated as a fatal error here.
if not train_data:
    raise ValueError("❌ Il dataset di training è vuoto o malformato.")
if not valid_data:
    raise ValueError("❌ Il dataset di validazione è vuoto o malformato.")

# Wrap each split in a datasets.Dataset, keyed by split name.
dataset = {
    split_name: Dataset.from_list(rows)
    for split_name, rows in (("train", train_data), ("validation", valid_data))
}

# 🟢 Load TrOCR model and processor from CHECKPOINT
# "checkpoint-XXX" is a placeholder: replace it with the real checkpoint
# folder name before running. The path reuses colab_base so the
# "Colab Notebooks" folder is spelled in exactly one place.
checkpoint_dir = os.path.join(colab_base, "trocr_finetuned", "checkpoint-XXX")  # 👈 UPDATE THIS
# The processor comes from the base model (or whichever one was used for the
# original fine-tuning run); only the model weights come from the checkpoint.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
model = VisionEncoderDecoderModel.from_pretrained(checkpoint_dir)

print("✅ Checkpoint caricato correttamente.")

# 🛠️ Configure model
# Align the model's special-token ids and vocabulary size with the tokenizer
# and decoder in use.
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.vocab_size = model.config.decoder.vocab_size
# Generation settings belong on generation_config; setting them on
# model.config is deprecated in current transformers releases.
model.generation_config.max_length = 16

# 🔄 Preprocessing function
def preprocess(example):
    """Convert one dataset record into TrOCR training inputs.

    Args:
        example: Record with an ``"image_path"`` entry (absolute, or
            relative to ``local_base``) and a ``"text"`` entry holding the
            target transcription.

    Returns:
        A dict with two NumPy arrays: ``"pixel_values"`` (the processor
        output for the image with its leading batch axis removed) and
        ``"labels"`` (token ids padded/truncated to 16 entries, with every
        padding position replaced by -100).
    """
    image_path = example["image_path"]
    if not os.path.isabs(image_path):
        image_path = os.path.join(local_base, image_path)

    # Open inside a context manager so the file handle is released
    # deterministically; convert() returns an independent in-memory image
    # that stays valid after the source file is closed.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")
    pixel = processor(images=image, return_tensors="pt").pixel_values.squeeze(0)

    label_ids = processor.tokenizer(
        example["text"],
        padding="max_length",
        max_length=16,
        truncation=True,
        return_tensors="pt"
    ).input_ids.squeeze(0)

    # -100 is the default ignore_index of PyTorch's cross-entropy loss, so
    # padding positions are marked with it instead of the pad token id.
    label_ids[label_ids == processor.tokenizer.pad_token_id] = -100

    return {
        "pixel_values": pixel.numpy(),
        "labels": label_ids.numpy()
    }

# ⚙️ Preprocess datasets
# Run preprocess over every record of each split ("train" first, then
# "validation"), keeping the same split names as keys.
processed_dataset = {
    split_name: dataset[split_name].map(preprocess)
    for split_name in ("train", "validation")
}

# 🧩 Custom collate function
def custom_collate_fn(features):
    """Assemble a list of preprocessed records into one training batch.

    Args:
        features: Sequence of dicts, each holding ``"pixel_values"`` and
            ``"labels"`` in a form ``torch.tensor`` accepts.

    Returns:
        Dict with ``"pixel_values"`` (all images stacked along a new first
        axis) and ``"labels"`` (label sequences right-padded with -100 to
        the longest sequence in the batch, batch axis first).
    """
    image_tensors = [torch.tensor(sample["pixel_values"]) for sample in features]
    batch_pixels = torch.stack(image_tensors)

    label_tensors = [torch.tensor(sample["labels"]) for sample in features]
    batch_labels = torch.nn.utils.rnn.pad_sequence(
        label_tensors,
        batch_first=True,
        padding_value=-100,
    )

    return dict(pixel_values=batch_pixels, labels=batch_labels)

# 🛠️ Training arguments
training_args = Seq2SeqTrainingArguments(
    # Where checkpoints and the final model are written.
    output_dir=os.path.join(local_base, "trocr_finetuned"),
    # Optimisation schedule.
    num_train_epochs=10,  # adjust as needed
    learning_rate=5e-6,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # Mixed precision only when a CUDA device is present.
    fp16=torch.cuda.is_available(),
    # Evaluate and checkpoint once per epoch; log every 50 steps.
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="steps",
    logging_steps=50,
    predict_with_generate=True,
    # Stored on the arguments object only: per the Transformers docs the
    # Trainer does not consume this field itself. The actual resume is
    # triggered by the resume_from_checkpoint argument given to
    # trainer.train().
    resume_from_checkpoint=checkpoint_dir,
)

# 🏋️‍♂️ Initialize Trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=processed_dataset["train"],
    eval_dataset=processed_dataset["validation"],
    data_collator=custom_collate_fn,
    # `tokenizer=` is deprecated on Trainer (since transformers 4.46) and
    # scheduled for removal in v5; `processing_class=` is its replacement
    # and still makes the tokenizer get saved with each checkpoint.
    processing_class=processor.tokenizer,
)

# 🚀 Start training
# Resume from checkpoint_dir (the folder the model was loaded from) rather
# than starting a fresh run.
trainer.train(resume_from_checkpoint=checkpoint_dir)

# 📈 Plot training & validation loss
import matplotlib.pyplot as plt

logs = trainer.state.log_history

# Training points are kept as two parallel lists; validation points are kept
# as (epoch, loss) pairs and split apart just before plotting.
train_loss, val_loss, epochs = [], [], []

for log in logs:
    # Entries without an epoch cannot be placed on the x axis.
    if 'epoch' not in log:
        continue
    if 'loss' in log:
        epochs.append(log['epoch'])
        train_loss.append(log['loss'])
    elif 'eval_loss' in log:
        val_loss.append((log['epoch'], log['eval_loss']))

plt.figure(figsize=(10, 5))
plt.plot(epochs, train_loss, marker='o', label="Training Loss")
if val_loss:
    val_epochs, val_values = zip(*val_loss)
    plt.plot(val_epochs, val_values, marker='o', label="Validation Loss")
plt.title("Training & Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.grid(True)
plt.legend()
plt.show()

# 💾 Save final model
trainer.save_model(output_dir=os.path.join(local_base, "trocr_finetuned"))

# ✅ Copy results back to Drive
# The whole local "trocr_finetuned" folder is copied; dirs_exist_ok=True lets
# the copy merge into a destination folder that already exists.
shutil.copytree(
    src=os.path.join(local_base, "trocr_finetuned"),
    dst=os.path.join(colab_base, "trocr_finetuned"),
    dirs_exist_ok=True,
)

print("✅ Training completato e salvato su Drive.")
