In [None]:
import random
import torch
import pickle
import re
import numpy as np
import tqdm.auto as tqdm
from torch.utils.data import Dataset, DataLoader, random_split
from datasets import load_dataset
from PIL import Image
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
from transformers import AutoProcessor, BlipForConditionalGeneration

# Only PyTorch's global RNG is seeded in this cell; Python's `random` and
# NumPy's RNG are not seeded here.
# DO NOT MODIFY THIS LINE
torch.manual_seed(217) # seed to ensure consistent random split of data

# Helper functions

In [None]:
def plot_skeleton_with_im(im, keypoints, connect=True):
    """Draw numbered keypoints (and optionally the skeleton) on top of an image.

    Args:
        im: Image handed directly to ``plt.imshow``.
        keypoints: Sequence of ``(x, y)`` pairs. May be empty (for example
            when no coordinates could be parsed from a generated caption), in
            which case only the image is shown.
        connect: When True, draw a red segment for every pair in
            ``connections`` whose two indices both exist in ``keypoints``.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Pairs of keypoint indices that are joined by a line segment.
    connections = [(0, 1), (1, 2), (3, 4), (4, 5), (10, 11), (11, 12),
                   (13, 14), (14, 15), (2, 6), (3, 6), (6, 7), (7, 8),
                   (8, 9), (12, 7), (13, 7), (8, 12), (8, 13)]
    num_points = len(keypoints)

    plt.figure(figsize=(8, 8))  # Adjust the figure size as needed
    plt.imshow(im)

    # Unpacking an empty keypoint sequence would call plt.scatter() without
    # its required x/y arguments, so only scatter when there is data.
    if num_points > 0:
        plt.scatter(*zip(*keypoints), s=50, c='b', marker='o')  # blue circles

    # Label every keypoint with its index.
    for i, (x, y) in enumerate(keypoints):
        plt.text(x, y, str(i), fontsize=12, ha='center', va='bottom', color='b')

    if connect:
        for point1, point2 in connections:
            # A partial prediction may not contain both endpoints; skip the
            # segment instead of raising IndexError.
            if point1 >= num_points or point2 >= num_points:
                continue
            x_values = [keypoints[point1][0], keypoints[point2][0]]
            y_values = [keypoints[point1][1], keypoints[point2][1]]
            plt.plot(x_values, y_values, 'r')

    plt.show()

def plot_skeleton(keypoints):
    """Draw numbered keypoints and the connecting skeleton on an empty figure.

    The y axis is inverted after plotting so that y grows downwards.

    Args:
        keypoints: Sequence of ``(x, y)`` pairs. May be empty or contain fewer
            points than the skeleton references; segments whose endpoints are
            missing are skipped.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Pairs of keypoint indices that are joined by a line segment.
    connections = [(0, 1), (1, 2), (3, 4), (4, 5), (10, 11), (11, 12),
                   (13, 14), (14, 15), (2, 6), (3, 6), (6, 7), (7, 8),
                   (8, 9), (12, 7), (13, 7), (8, 12), (8, 13)]
    num_points = len(keypoints)

    plt.figure(figsize=(8, 8))  # Adjust the figure size as needed

    # Unpacking an empty keypoint sequence would call plt.scatter() without
    # its required x/y arguments, so only scatter when there is data.
    if num_points > 0:
        plt.scatter(*zip(*keypoints), s=50, c='b', marker='o')  # blue circles

    # Label every keypoint with its index.
    for i, (x, y) in enumerate(keypoints):
        plt.text(x, y, str(i), fontsize=12, ha='center', va='bottom', color='b')

    for point1, point2 in connections:
        # Skip segments whose endpoints are not present instead of raising
        # IndexError on a partial set of keypoints.
        if point1 >= num_points or point2 >= num_points:
            continue
        x_values = [keypoints[point1][0], keypoints[point2][0]]
        y_values = [keypoints[point1][1], keypoints[point2][1]]
        plt.plot(x_values, y_values, 'r')

    plt.gca().invert_yaxis()
    plt.show()

In [None]:
def get_validation_accuracy(model, val_dataloader, threshold=16):
    """Run the model over a validation dataloader and report summary metrics.

    For every batch the function (a) generates a caption from the pixel values
    and (b) runs a teacher-forced forward pass to obtain the loss. Coordinate
    pairs are parsed out of the generated caption and of the decoded reference
    text with the pattern ``(\\d+), (\\d+)``. A sample is "correct" when the
    mean absolute difference between predicted and reference coordinates is
    strictly below ``threshold``.

    Relies on the module-level names ``device`` and ``processor``.

    Args:
        model: Model exposing ``eval``/``train``/``generate`` and a forward
            call returning an object with a ``loss`` attribute.
        val_dataloader: Iterable of batches holding ``input_ids``,
            ``pixel_values`` and ``attention_mask`` tensors.
        threshold: Upper bound (exclusive) on the per-sample mean absolute
            error for the sample to count as correct.

    Returns:
        Tuple ``(accuracy, mean_loss, mean_error)`` where ``accuracy`` is
        ``correct / total``, ``mean_loss`` is the loss averaged over batches
        and ``mean_error`` is the error averaged over the samples that could
        be scored. Any of the three is ``nan`` when its denominator is zero.

    Side effects:
        Prints the metrics and leaves the model in training mode.
    """
    model.eval()
    correct = 0
    total = 0
    total_err = 0
    missing_label = 0
    val_loss = 0
    num_batch = 0

    progress_bar = tqdm.tqdm(val_dataloader)
    progress_bar.set_description("Running validation: ")
    for batch in progress_bar:
        num_batch += 1
        input_ids = batch.pop("input_ids").to(device)
        pixel_values = batch.pop("pixel_values").to(device)
        attention_mask = batch.pop("attention_mask").to(device)

        # Reference captions, decoded back to text so they can be parsed with
        # the same pattern as the generated captions.
        text = processor.batch_decode(input_ids, skip_special_tokens=True)

        with torch.no_grad():
            generated_ids = model.generate(pixel_values=pixel_values, max_length=300)
            fw_outputs = model(input_ids=input_ids,
                               pixel_values=pixel_values,
                               labels=input_ids,
                               attention_mask=attention_mask)

        val_loss += fw_outputs.loss.item()

        generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)
        total += len(generated_caption)

        for caption, reference in zip(generated_caption, text):
            prediction = np.array(re.findall(r'(\d+), (\d+)', caption), dtype=float)

            # A usable prediction contains exactly 16 coordinate pairs.
            if prediction.shape[0] != 16:
                missing_label += 1
                continue

            truth = np.array(re.findall(r'(\d+), (\d+)', reference), dtype=float)

            # Without this guard a reference with a different number of pairs
            # would either broadcast silently or raise inside the subtraction;
            # such samples are counted as unscorable instead.
            if truth.shape != prediction.shape:
                missing_label += 1
                continue

            avg_err = np.mean(np.abs(prediction - truth))
            total_err += avg_err

            if avg_err < threshold:
                correct += 1

    # Guard every ratio: the dataloader may be empty and every prediction may
    # have been unscorable, which previously raised ZeroDivisionError.
    scored = total - missing_label
    accuracy = correct / total if total else float("nan")
    mean_loss = val_loss / num_batch if num_batch else float("nan")
    mean_err = total_err / scored if scored else float("nan")

    print("Validation accuracy:", accuracy)
    print("Validation loss:", mean_loss)
    print("Total:", total)
    print("Correct:", correct)
    print("Missing label:", missing_label)
    print("Average error:", mean_err)

    model.train()

    return accuracy, mean_loss, mean_err

In [None]:
class ImageCaptioningDataset(Dataset):
    """Dataset of (image, caption text) pairs encoded by a processor.

    Args:
        data_dict: Mapping with an ``'img_name'`` sequence of file names and a
            parallel ``'text'`` sequence of captions.
        processor: Callable invoked as
            ``processor(images=..., text=..., padding="max_length",
            return_tensors="pt")`` that returns a mapping of tensors carrying
            a leading batch dimension.
        image_dir: Directory the names in ``data_dict['img_name']`` are
            resolved against. Defaults to the original hard-coded location.

    Attributes:
        images: One opened image per sample, in ``data_dict`` order.
        text: One caption per sample, in ``data_dict`` order.
    """

    def __init__(self, data_dict, processor, image_dir="/kaggle/input/28231999/processed_data"):
        self.processor = processor
        # NOTE(review): every image is opened up front; confirm this is
        # acceptable for the dataset size (open file handles / memory).
        self.images = [Image.open(f"{image_dir}/{name}") for name in data_dict['img_name']]
        self.text = list(data_dict['text'])

    def __len__(self):
        """Return the number of samples (one per caption)."""
        return len(self.text)

    def __getitem__(self, idx):
        """Return the processor encoding of sample ``idx`` without the batch dimension."""
        image = self.images[idx]
        text = self.text[idx]
        encoding = self.processor(images=image, text=text, padding="max_length", return_tensors="pt")
        # Remove only the leading batch dimension. A bare squeeze() would also
        # collapse any other size-1 dimension (e.g. a length-1 sequence),
        # changing the tensor rank and breaking batching.
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}
        return encoding

# Load the preprocessed MPII dataset

In [None]:
# Read the preprocessed dataset dictionary from disk.
# NOTE: unpickling can execute arbitrary code, so only load files from a
# trusted source.
with open('/kaggle/input/28231999/processed_data/processed_data.pkl', 'rb') as pickle_file:
    data = pickle.load(pickle_file)

# Remove samples whose caption does not contain exactly 16 joint coordinates

In [None]:
# Parse the "x,y" pairs out of every caption. A sample whose caption does not
# yield exactly 16 pairs is reported, recorded in `bad_groups`, and keeps an
# all-zero row in `joint_coords`.
joint_coords = np.zeros((len(data['text']), 16, 2))
bad_groups = []
for i, joints in enumerate(data['text']):
    coords = np.array(re.findall(r'(\d+),(\d+)', joints), dtype=int)
    if coords.shape[0] != 16:
        print(i)
        bad_groups.append(i)
    else:
        joint_coords[i] = coords

# Membership tests against a set keep the filtering below linear in the
# dataset size (a list lookup would rescan `bad_groups` for every sample).
bad_group_set = set(bad_groups)

# dict_keys(['img_name', 'joint_points', 'text', 'simple_text', 'visibility_text', 'simple_visibility_text', 'normalized_text', 'normalized_simple_text'])
# Keep only the fields used downstream, dropping the bad samples from each.
processed_data = {
    key: [value for i, value in enumerate(data[key]) if i not in bad_group_set]
    for key in ('img_name', 'text', 'joint_points', 'simple_text')
}

# Visualize a data point

In [None]:
# Overlay the parsed joints of one sample on its image.
idx = 0
sample_name = data['img_name'][idx]
sample_joints = joint_coords[idx]
image = Image.open(f"/kaggle/input/28231999/processed_data/{sample_name}")
plt.figure(figsize=(8, 8))
plt.imshow(image)
plt.plot(sample_joints[:, 0], sample_joints[:, 1], 'o')

# Load the HuggingFace BLIP Captioning model

In [None]:
# Re-executing this cell reloads the pretrained weights, which is how to start
# over before fine-tuning with a different set of parameters.
pretrained_name = "Salesforce/blip-image-captioning-base"
processor = AutoProcessor.from_pretrained(pretrained_name)
model = BlipForConditionalGeneration.from_pretrained(pretrained_name)

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)

# Hyperparameters

In [None]:
batch_size = 4
lr = 2e-5
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)


def lambda1(epoch):
    """Learning-rate multiplier: 1.0 before epoch 15, 0.2 from epoch 15 on."""
    if epoch >= 15:
        return 0.2
    return 1.0


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1, verbose=True)
# Alternative scheduler:
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, min_lr=1e-5, threshold=0.1, patience=5, verbose=True)

# Initialize dataloaders

In [None]:
dataset = ImageCaptioningDataset(processed_data, processor)

# 80% of the samples go to training; the remainder is the validation split.
num_samples = len(dataset)
train_size = int(0.8 * num_samples)
val_size = num_samples - train_size
split_sizes = [train_size, val_size]

train_dataset, val_dataset = random_split(dataset, split_sizes)

# Only the training loader shuffles.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Metric histories filled in by the training loop.
running_losses, val_accuracies, val_losses, avg_errs = [], [], [], []

# Training loop

In [None]:
num_epoch = 25
# Epochs at which a checkpoint is written once the optional checkpoint block
# at the end of the epoch loop is uncommented. It must stay defined because
# it is printed below.
checkpoints = [1, 15, 23]

print("Number of epochs:", num_epoch)
print("Batch size:", train_dataloader.batch_size)
print("Learning rate:", optimizer.param_groups[0]['lr'])
print("Optimizer:", optimizer)

print("Checkpoints will be save at epoch:", checkpoints)

print("Start training on device:", device)
model.train()
for epoch in range(num_epoch):
    print("Epoch:", epoch)
    progress_bar = tqdm.tqdm(train_dataloader)
    for batch in progress_bar:

        input_ids = batch.pop("input_ids").to(device)
        pixel_values = batch.pop("pixel_values").to(device)
        attention_mask = batch.pop("attention_mask").to(device)

        optimizer.zero_grad()

        # The caption tokens serve as both the decoder input and the labels.
        outputs = model(input_ids=input_ids,
                        pixel_values=pixel_values,
                        labels=input_ids,
                        attention_mask=attention_mask)

        loss = outputs.loss
        loss_value = loss.item()

        progress_bar.set_description(f"Loss: {loss_value}")
        # Per-iteration loss history, plotted and pickled in later cells.
        running_losses.append(loss_value)

        loss.backward()
        optimizer.step()

    # uncomment to initiate validation during training
#     if (epoch % 4 == 0 and epoch != 0) or epoch == 1:
#         val_accuracy, val_loss, avg_err = get_validation_accuracy(model, val_dataloader, 25)
#         val_accuracies.append(val_accuracy)
#         val_losses.append(val_loss)
#         avg_errs.append(avg_err)

    # uncomment to take checkpoints (one file per checkpoint epoch)
#     if epoch in checkpoints:
#         torch.save(model.state_dict(), '/kaggle/working/model_epoch_{}.pt'.format(epoch))

    # uncomment to use lr scheduler
#     scheduler.step()

print("Finish trainig")

# Plot and save training and validation curves

In [None]:
def _plot_and_save_curve(values, title, y_label, out_path):
    """Plot one metric history, save the figure to ``out_path`` and show it."""
    plt.plot(values)
    plt.title(title)
    plt.xlabel('Number of iterations')
    plt.ylabel(y_label)
    plt.savefig(out_path)
    plt.show()


# (history, title, y-axis label, output file) for each curve, in plot order.
for curve in (
    (running_losses, 'Training Loss', 'Loss', '/kaggle/working/training_loss.jpg'),
    (val_accuracies, 'Validation Accuracy', 'Accuracy', '/kaggle/working/validation_accuracy.jpg'),
    (val_losses, 'Validation Loss', 'Loss', '/kaggle/working/validation_loss.jpg'),
    (avg_errs, 'Average Error', 'Error', '/kaggle/working/avg_err.jpg'),
):
    _plot_and_save_curve(*curve)

In [None]:
# Pickle every metric history to its own file, in the order listed.
metric_dumps = [
    ('/kaggle/working/training_loss.pkl', running_losses),
    ('/kaggle/working/validation_accuracy.pkl', val_accuracies),
    ('/kaggle/working/validation_loss.pkl', val_losses),
    ('/kaggle/working/avg_err.pkl', avg_errs),
]
for dump_path, history in metric_dumps:
    with open(dump_path, 'wb') as file:
        pickle.dump(history, file)

In [None]:
# Write the current model weights (state dict only) to the working directory.
final_weights = model.state_dict()
torch.save(final_weights, '/kaggle/working/model_epoch_final.pt')

In [None]:
# load a trained model
# map_location remaps the stored tensors onto the active device so that a
# checkpoint saved from a GPU run can also be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted
# source.
model.load_state_dict(
    torch.load('/kaggle/input/mpiistartmodel/start_model/late.pt', map_location=device)
)
model.to(device)

# Validation

In [None]:
# Score the validation split; a sample counts as correct only when its mean
# absolute coordinate error is below 1.
val_accuracy, val_loss, avg_err = get_validation_accuracy(model, val_dataloader, threshold=1)

# Inference on a random sample from the validation set

In [None]:
# Shuffled loader that yields one validation sample per batch.
random_val_dataloader = DataLoader(dataset=val_dataset, batch_size=1, shuffle=True)

In [None]:
model.eval()

# Pick a random position in the validation split and translate it into the
# corresponding index of the underlying (unsplit) dataset.
num_val = len(val_dataset.indices)
index = val_dataset.indices[np.random.randint(0, num_val)]

print(index)

full_dataset = val_dataset.dataset
sample = full_dataset[index]

input_ids = sample["input_ids"].to(device)
pixel_values = sample["pixel_values"].to(device)
# Reference caption decoded back to text.
text = processor.decode(input_ids, skip_special_tokens=True)
# Channels-last copy of the preprocessed pixels (only needed for the
# commented-out plot below).
processed_image = np.transpose(pixel_values.cpu().numpy(), (1, 2, 0))
image = full_dataset.images[index]
print(text)

with torch.no_grad():
    generated_ids = model.generate(pixel_values=pixel_values.unsqueeze(0), max_length=300)
generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_caption)

# Parse the "x, y" pairs of both captions into integer arrays.
pair_pattern = r'(\d+), (\d+)'
prediction = np.array(re.findall(pair_pattern, generated_caption), dtype=int)
truth = np.array(re.findall(pair_pattern, text), dtype=int)
# plot_skeleton_with_im(processed_image, prediction, False)
# Prediction first, then ground truth, both over the stored image.
for keypoints in (prediction, truth):
    plot_skeleton_with_im(image, keypoints, False)

# Inference on a user input file

In [None]:
# Replace FILE_PATH with the image to run inference on; it is resized to
# 384x384 before being handed to the processor.
user_image = Image.open(f"FILE_PATH")
image = user_image.resize((384, 384))

model.eval()
inputs = processor(images=image, return_tensors="pt").to(device)
pixel_values = inputs.pixel_values
# Channels-last copy of the preprocessed pixels, used as a plot background.
processed_image = np.transpose(pixel_values[0].cpu().numpy(), (1, 2, 0))
print()

with torch.no_grad():
    generated_ids = model.generate(pixel_values=pixel_values, max_length=300)
generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_caption)

# Parse the "x, y" pairs of the generated caption into an integer array.
pair_pattern = r'(\d+), (\d+)'
prediction = np.array(re.findall(pair_pattern, generated_caption), dtype=int)
# Show the prediction over the preprocessed pixels, then over the resized image.
for background in (processed_image, image):
    plot_skeleton_with_im(background, prediction, False)