In [None]:
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoImageProcessor, AutoTokenizer, VisionEncoderDecoderModel
from transformers import get_scheduler

In [None]:
# Install / upgrade the third-party packages this notebook relies on.
# %pip (instead of !pip) installs into the environment of the running kernel,
# and each package is listed once; the extras spec is quoted so the shell does
# not treat the square brackets as a glob pattern.
%pip install -U accelerate "transformers[torch]"
%pip install rouge_score evaluate

In [None]:
# Load the image processor, the caption tokenizer and the encoder-decoder model.
# The checkpoint names are held in one place so processor/tokenizer and model
# cannot drift apart.
ENCODER_CHECKPOINT = "google/vit-large-patch16-224-in21k"
DECODER_CHECKPOINT = "bert-large-uncased"

image_processor = AutoImageProcessor.from_pretrained(ENCODER_CHECKPOINT)
# NOTE(review): add_special_tokens is usually passed when *calling* the tokenizer,
# not to from_pretrained; kept unchanged here - confirm it has the intended effect.
decoder_tokenizer = AutoTokenizer.from_pretrained(DECODER_CHECKPOINT, add_special_tokens=True)

model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(ENCODER_CHECKPOINT, DECODER_CHECKPOINT)

# When the model is called with `labels` only, it derives the decoder inputs by
# shifting the labels right, which needs both ids below. They are not set by
# from_encoder_decoder_pretrained, and the forward pass raises ValueError
# without them.
model.config.decoder_start_token_id = decoder_tokenizer.cls_token_id
model.config.pad_token_id = decoder_tokenizer.pad_token_id

In [None]:
# Training hyperparameters.
epochs = 1      # number of passes over the training dataloader
batch_size = 1  # batch size handed to the DataLoader

In [None]:
# Shuffled training batches, assembled by collate_fn.
# NOTE(review): train_dataset and collate_fn are not defined in the visible
# cells - confirm an earlier cell provides them on a fresh kernel.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)

In [None]:
# Optimizer, learning-rate schedule and device placement.
optimizer = AdamW(model.parameters(), lr=5e-5)

# The scheduler is stepped once per batch, so the total number of steps is
# epochs * batches-per-epoch.
num_training_steps = epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=num_training_steps // 5,  # warm up over the first fifth of training
    num_training_steps=num_training_steps,
)

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")

# The trailing semicolon stops the notebook from echoing the full module repr.
model.to(device);

In [None]:
# Train the model. After every epoch: optionally checkpoint, print the summed
# training loss, then print token-level accuracy measured on the training set.

# Label value that the loss ignores (PyTorch / Hugging Face convention).
# NOTE(review): if collate_fn pads captions with the tokenizer's pad id instead
# of -100, padding positions still count towards the accuracy - confirm.
IGNORE_INDEX = -100
# Numbers of *completed* epochs after which a checkpoint is written.
CHECKPOINT_EPOCHS = {2, 4, 6, 8, 10}

model.train()
progress_bar = tqdm(range(num_training_steps), desc='Training')
# Epochs are numbered from 1 so that "Save_at_N_epochs" really means N completed
# epochs and the printed epoch numbers agree with the checkpoint names.
for epoch in range(1, epochs + 1):
    losses = []
    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(pixel_values=batch['pixel_values'], labels=batch['caption_token'])
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        # Keep a plain float: storing the tensor itself would keep one device
        # tensor alive per step and print a tensor repr instead of a number.
        losses.append(loss.item())
        progress_bar.update(1)

    # Release cached GPU blocks once per epoch; doing it on every step only
    # slows training down.
    torch.cuda.empty_cache()

    if epoch in CHECKPOINT_EPOCHS:
        model.save_pretrained(f"Save_at_{epoch}_epochs.pt")
        print(f'Saved at {epoch}')

    print('Epoch: {}, Loss: {}'.format(epoch, sum(losses)))

    # Evaluate the model on the training set (teacher-forced: the labels are
    # passed to the model, so this is next-token accuracy, not generation quality).
    model.eval()
    total_correct = 0
    total_tokens = 0
    with torch.no_grad():
        for batch in train_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            labels = batch['caption_token']
            outputs = model(pixel_values=batch['pixel_values'], labels=labels)
            predicted_ids = outputs.logits.argmax(dim=-1)

            # Count matches per token and divide by the number of scored tokens;
            # dividing by the number of sequences would let the ratio exceed 1.
            scored = labels != IGNORE_INDEX
            total_correct += ((predicted_ids == labels) & scored).sum().item()
            total_tokens += scored.sum().item()

    # Guard against an empty dataloader.
    accuracy = total_correct / total_tokens if total_tokens else 0.0
    print(f'Epoch {epoch} Accuracy: {accuracy:.4f}')

    model.train()

In [None]:
# Persist the fine-tuned model after training has finished.
# Per the transformers documentation, save_pretrained treats its argument as a
# directory to write into, despite the ".pt" suffix used here.
final_model_path = "final-model.pt"
model.save_pretrained(final_model_path)