In [3]:
import random
import os
from PIL import Image
from word2number import w2n
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from transformers import ViltProcessor, ViltForQuestionAnswering

  warn(


`split.py`

In [4]:
# Split questions_shape.txt into train / validation / test files.
# 70% of the samples go to train, 10% to validation, and the remainder to test.
SPLIT_SEED = 42  # fixed seed so Restart & Run All reproduces the same split

with open("questions_shape.txt", "r") as file:
    # Drop blank lines and re-terminate every record with "\n". Without this,
    # a final line that lacks a trailing newline would be glued to the next
    # record once the lines are shuffled and written back out.
    lines = [line.rstrip("\n") + "\n" for line in file if line.strip()]


total_samples = len(lines)
train_samples = int(0.7 * total_samples)
val_samples = int(0.1 * total_samples)
test_samples = total_samples - train_samples - val_samples


# A dedicated generator keeps the shuffle deterministic without touching the
# global `random` state used elsewhere.
random.Random(SPLIT_SEED).shuffle(lines)


train_data = lines[:train_samples]
val_data = lines[train_samples:train_samples + val_samples]
test_data = lines[train_samples + val_samples:]

# NOTE(review): the later cells read 'data/train_data.txt', 'data/val_data.txt'
# and 'data/test_data.txt', while these files are written to the working
# directory - confirm they are moved into data/ before fine-tuning.
for out_path, subset in (
    ("train_data.txt", train_data),
    ("val_data.txt", val_data),
    ("test_data.txt", test_data),
):
    with open(out_path, "w") as file:
        file.writelines(subset)


`test_vilt.py`

In [3]:
# Load the pretrained ViLT VQA checkpoint together with its processor.
# ViltProcessor is used because it is the class the notebook's import cell
# provides; AutoProcessor is not imported there, so referencing it fails with
# a NameError.
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

`fine_tune.py`

In [1]:
def batch_loader(file_path, batch_size=32):
    """Yield batches of (images, questions, labels) read from a text file.

    Each usable line has exactly four comma-separated fields; the second is
    the question, the third the answer as a number word, and the fourth the
    image path. Lines with any other field count are skipped silently.

    Args:
        file_path: Path of the UTF-8 text file to read.
        batch_size: Number of samples per yielded batch. The last batch may
            be smaller.

    Yields:
        A tuple ``(images, questions, labels)`` of three equally long lists:
        RGB PIL images resized to 384x384, question strings, and the answers
        converted to digit strings (e.g. "three" -> "3").

    Samples whose answer is not a number word, or whose image cannot be
    opened, are reported with ``print`` and skipped.
    """
    batch_images = []
    batch_questions = []
    batch_labels = []

    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            parts = line.strip().split(',')
            if len(parts) != 4:
                continue
            _, question, answer, image_path = parts

            # Validate the label first so nothing is appended for a sample
            # that ends up being rejected (the three lists must stay aligned).
            try:
                label = str(w2n.word_to_num(answer))
            except ValueError:
                # word_to_num raises ValueError when no number word is found;
                # skip the sample instead of aborting the whole epoch.
                print(f"Skipping sample with non-numeric answer {answer!r}")
                continue

            try:
                # Force three channels so grayscale / RGBA files produce the
                # same input layout as ordinary RGB images.
                img = Image.open(image_path).convert("RGB").resize((384, 384))
            except IOError:
                print(f"Error opening image {image_path}")
                continue

            batch_images.append(img)
            batch_questions.append(question)
            batch_labels.append(label)

            if len(batch_images) == batch_size:
                yield batch_images, batch_questions, batch_labels
                batch_images, batch_questions, batch_labels = [], [], []

    if batch_images:  # Yield any remaining data as the last batch
        yield batch_images, batch_questions, batch_labels


def process_and_predict(images, questions):
    """Encode image/question pairs and run them through the model.

    Relies on the module-level ``processor`` and ``model`` objects.

    Args:
        images: Batch of images to encode (the loader in this notebook
            supplies PIL images).
        questions: Question strings, one per image.

    Returns:
        The model's output object; callers read its ``logits`` attribute.
    """
    # truncation=True mirrors the encoding used by the training loop, so
    # over-long questions are truncated the same way at train and eval time.
    encoding = processor(
        images,
        questions,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    return model(**encoding)

def valid(model, best_acc):
    """Evaluate on the validation split and checkpoint when accuracy improves.

    Every batch from 'data/val_data.txt' is run through
    ``process_and_predict``; each question is printed with its ground-truth
    label and the predicted answer, followed by the batch accuracy. The
    overall accuracy is the fraction of correctly answered samples across the
    whole split, so a short final batch is not over-weighted.

    Args:
        model: Model providing ``cpu()``, ``eval()``, ``config.id2label`` and
            ``save_pretrained()``. NOTE(review): the forward pass goes through
            ``process_and_predict``, which uses the module-level ``model`` -
            presumably the same object; confirm.
        best_acc: Best validation accuracy observed so far.

    Returns:
        The best accuracy after this evaluation: the new overall accuracy if
        it exceeded ``best_acc`` (the model is then saved to 'checkpoint/'),
        otherwise ``best_acc`` unchanged. Also returned unchanged when the
        validation split yields no samples.
    """
    print("--------------------this is validation------------------------")
    model.cpu()
    model.eval()

    total_matches = 0
    total_samples = 0

    # Inference only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        for images, questions, labels in batch_loader('data/val_data.txt'):
            outputs = process_and_predict(images, questions)
            # argmax over the raw logits selects the same class as argmax
            # over their softmax, so the softmax step is unnecessary.
            max_prob_indices = torch.argmax(outputs.logits, dim=1)
            word_list = []
            for i, k in enumerate(max_prob_indices):
                answer = model.config.id2label[int(k)]
                print(questions[i])
                print("-------------ground truth --------------")
                print(labels[i])
                print("-------------our answer --------------")
                print(answer)
                word_list.append(answer)
            matches = sum(1 for x, y in zip(word_list, labels) if x == y)
            probability = matches / len(labels)
            print("-------------batch accuracy --------------")
            print(probability)
            total_matches += matches
            total_samples += len(labels)

    if total_samples == 0:
        # Nothing to evaluate; avoid dividing by zero and keep the old best.
        print("no validation samples found")
        return best_acc

    average = total_matches / total_samples
    print("-------------average accuracy --------------")
    print(average)
    if average > best_acc:
        best_acc = average
        save_path = "checkpoint/"
        model.save_pretrained(save_path)
        print(f"save in:{save_path}")
    # Return the running best so the caller can carry it across epochs.
    return best_acc

In [None]:
# Fine-tuning setup: load the processor and model weights from the local ViLT
# directory, then build the loss and the optimizer. No .cuda() move is
# performed, so the model stays on the device from_pretrained puts it on.
processor = ViltProcessor.from_pretrained("/data/ryh/xianyu/vilt/")
model = ViltForQuestionAnswering.from_pretrained("/data/ryh/xianyu/vilt/")

# Classification loss over the answer logits; Adam with a learning rate of 1e-5.
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5)

In [6]:
# Fine-tune for 10 epochs, validating after each one.
# best_acc lives outside the loop so that a checkpoint is only written when
# validation accuracy actually improves, not after every epoch.
best_acc = 0

for epoch in range(10):
    model.train()  # valid() switches the model to eval mode, so re-enable training
    train_loss = 0.0   # sum of per-sample losses over the epoch
    train_correct = 0  # number of correctly predicted training samples
    num = 0            # number of training samples seen this epoch

    for images, questions, labels in batch_loader('data/train_data.txt', batch_size=32):
        # Map each answer string to its class index; a label missing from
        # model.config.label2id raises KeyError here.
        label_list = [model.config.label2id[label] for label in labels]
        labels_tensor = torch.tensor(label_list)

        encoding = processor(images, questions, return_tensors="pt", padding=True, truncation=True)

        optimizer.zero_grad()
        outputs = model(**encoding)
        logits = outputs.logits
        loss = loss_function(logits, labels_tensor)
        loss.backward()
        optimizer.step()

        n_batch = len(labels)
        num += n_batch
        # loss is a batch mean, so weight it by the batch size before summing.
        train_loss += loss.item() * n_batch
        predictions = torch.argmax(logits.detach(), dim=1)
        train_correct += (predictions == labels_tensor).sum().item()

    # Per-sample averages over the epoch. Summing the running totals once per
    # batch (as done previously) counted early batches many times over.
    avg_train_loss1 = train_loss / num if num else float("nan")
    avg_train_acc1 = train_correct / num if num else float("nan")

    print(f"Epoch {epoch+1}, Average Training Loss: {avg_train_loss1}, Average Training Accuracy: {avg_train_acc1}")

    # Carry the best validation accuracy forward when valid() reports one;
    # a None return leaves best_acc untouched.
    updated_best = valid(model, best_acc)
    if updated_best is not None:
        best_acc = updated_best

NameError: name 'model' is not defined

`test.py`

In [None]:
# Evaluate the model on the held-out test split.
accuracy = []        # per-batch accuracy values
total_matches = 0    # correctly answered samples across all batches
total_samples = 0    # samples evaluated across all batches

model.eval()  # disable dropout etc. even if no validation pass ran before this cell

# Inference only: no gradients are needed.
with torch.no_grad():
    for images, questions, labels in batch_loader('data/test_data.txt'):
        outputs = process_and_predict(images, questions)

        # argmax over the raw logits selects the same class as argmax over
        # their softmax, so the softmax step is unnecessary.
        max_prob_indices = torch.argmax(outputs.logits, dim=1)
        word_list = []

        for i, k in enumerate(max_prob_indices):
            answer = model.config.id2label[int(k)]
            print(questions[i])
            print("-------------ground truth --------------")
            print(labels[i])
            print("-------------our answer --------------")
            print(answer)
            word_list.append(answer)

        matches = sum(1 for x, y in zip(word_list, labels) if x == y)
        probability = matches / len(labels)
        print("-------------batch accuracy --------------")
        accuracy.append(probability)
        print(probability)
        total_matches += matches
        total_samples += len(labels)

if total_samples:
    # Fraction of correct samples over the whole split; averaging the
    # per-batch values would over-weight a short final batch.
    average = total_matches / total_samples
    print("-------------average accuracy --------------")
    print(average)
else:
    average = None
    print("no test samples found")