In [1]:
# Import necessary libraries (each module once; third-party libraries grouped together)
import numpy as np
import torch
import matplotlib.pyplot as plt
from pkg_resources import packaging  # NOTE(review): not used below, and pkg_resources is deprecated — candidate for removal

from datasets import load_dataset
from transformers import AutoImageProcessor, FlavaImageModel, FlavaModel, FlavaFeatureExtractor

print("Torch version:", torch.__version__)




Torch version: 2.0.0+cu117


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Check whether PyTorch can see a CUDA-capable GPU.
# Jupyter displays the value of the last expression (True/False) as the cell's output.
cuda_available = torch.cuda.is_available()
cuda_available

True

In [3]:
# Load the MNIST dataset (the "gorar/A-MNIST" Hub repository) with the `load_dataset` function.
# The result is indexed by split name ("train" / "test") in the cells below.
dataset = load_dataset(path="gorar/A-MNIST")

Found cached dataset a-mnist (/home/IAIS/jraghu/.cache/huggingface/datasets/gorar___a-mnist/amnist/1.1.0/49d6e25269c73523fbcc8d636818270c5604ddfdd1568ccabdcb39dc4416e954)
100%|██████████| 2/2 [00:00<00:00,  5.86it/s]


In [4]:
# Load the CIFAR-10 dataset as well; like `dataset`, it is indexed by split name.
cifar10 = load_dataset(path="cifar10")

Found cached dataset cifar10 (/home/IAIS/jraghu/.cache/huggingface/datasets/cifar10/plain_text/1.0.0/447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4)
100%|██████████| 2/2 [00:00<00:00,  5.47it/s]


In [5]:
# Inspect a single MNIST test image. Index the row first (`dataset["test"][0]`) so that only this
# one image is decoded; `dataset["test"]["image"][0]` would decode the entire test column first.
image = dataset["test"][0]["image"]
print(np.shape(image))

# The MNIST image is single-channel (H, W), but FLAVA needs 3-channel input, so add a trailing
# channel axis and replicate the grayscale plane three times -> (H, W, 3).
# NOTE: despite its name, `train_images` holds this one *test* image (name kept for compatibility).
train_images = np.expand_dims(image, axis=-1)
train_images = np.repeat(train_images, 3, axis=-1)

print(np.shape(train_images))

(28, 28)
(28, 28, 3)


In [6]:
# Simpler route to 3 channels: let PIL replicate the grayscale channel by converting to RGB.
t = image.convert(mode="RGB")
print(np.asarray(t).shape)

(28, 28, 3)


In [7]:
# Load the pre-trained FLAVA components used to build on the pre-trained model.
# A single constant names the checkpoint so that every component stays in sync.
FLAVA_CHECKPOINT = "facebook/flava-full"

# Full FLAVA model; only `get_image_features` is used below. eval() disables dropout.
flava = FlavaModel.from_pretrained(FLAVA_CHECKPOINT).eval()

# `FlavaFeatureExtractor` is a deprecated alias of `FlavaImageProcessor` (it only adds a
# deprecation warning), so load the image processor directly. Preprocessing is the same.
from transformers import FlavaImageProcessor
fe = FlavaImageProcessor.from_pretrained(FLAVA_CHECKPOINT)

# Only referenced by the commented-out experiment cell below.
image_processor = AutoImageProcessor.from_pretrained(FLAVA_CHECKPOINT)
model = FlavaImageModel.from_pretrained(FLAVA_CHECKPOINT)


`text_config_dict` is provided which will be used to initialize `FlavaTextConfig`. The value `text_config["id2label"]` will be overriden.
`multimodal_config_dict` is provided which will be used to initialize `FlavaMultimodalConfig`. The value `multimodal_config["id2label"]` will be overriden.
`image_codebook_config_dict` is provided which will be used to initialize `FlavaImageCodebookConfig`. The value `image_codebook_config["id2label"]` will be overriden.
Some weights of the model checkpoint at facebook/flava-full were not used when initializing FlavaModel: ['mlm_head.transform.LayerNorm.bias', 'mmm_text_head.transform.LayerNorm.bias', 'mmm_text_head.transform.dense.weight', 'image_codebook.blocks.group_2.group.block_1.res_path.path.conv_3.weight', 'mmm_image_head.transform.LayerNorm.bias', 'image_codebook.blocks.group_3.group.block_2.res_path.path.conv_3.bias', 'itm_head.pooler.dense.weight', 'image_codebook.blocks.group_3.group.block_1.id_path.bias', 'image_codebook.blocks.group_4.g

Trying out the image processor class (kept for reference; the cell below is commented out and not executed)

In [8]:
# inputs = image_processor(train_images, return_tensors="pt")

# with torch.no_grad():
#     outputs = model(**inputs)

# last_hidden_states = outputs.last_hidden_state
# list(last_hidden_states.shape)

Using a smaller subset of the MNIST training split (the first 6,000 images) to keep testing fast

In [9]:
# Take the first N_TRAIN_SMALL rows of the MNIST training split.
# Slicing rows (`dataset["train"][:N]`) decodes only those N images; the column-first form
# `dataset["train"]["image"][0:N]` decoded the entire image column before slicing.
N_TRAIN_SMALL = 6000
train_subset = dataset["train"][:N_TRAIN_SMALL]  # dict: column name -> list of N values
train_small = train_subset["image"]
train_small_label = train_subset["label"]

In [10]:
# Not populated anywhere in this notebook; kept so that the global still exists.
feature_list = []

# FLAVA needs 3-channel input, so convert every grayscale MNIST PIL image to RGB
# (PIL replicates the single channel). A comprehension avoids leaking the loop variables
# into the notebook namespace. Labels are not needed here; we only check that they line up.
assert len(train_small) == len(train_small_label), "MNIST images and labels are misaligned"
rgb_images = [sample.convert('RGB') for sample in train_small]
        

In [11]:
# Preprocess the RGB MNIST images and embed them with FLAVA's image encoder, keeping position 0
# (the [CLS] token) of the projected sequence as each image's feature vector.
# Work in mini-batches: one forward pass over all images materialises a huge pixel tensor
# (N x 3 x 224 x 224 floats) plus per-layer attention activations, which can exhaust memory.
EMBED_BATCH_SIZE = 64
flava_device = next(flava.parameters()).device  # run on whatever device the model is on

feature_batches = []
with torch.no_grad():
    for start in range(0, len(rgb_images), EMBED_BATCH_SIZE):
        batch_inputs = fe(rgb_images[start:start + EMBED_BATCH_SIZE], return_tensors="pt").to(flava_device)
        batch_features = flava.get_image_features(**batch_inputs)[:, 0, :]
        # Collect on the CPU so that the downstream `.numpy()` conversion works on any device.
        feature_batches.append(batch_features.cpu())

# (N, embedding_dim) tensor of MNIST embeddings
image_features = torch.cat(feature_batches, dim=0)

The FLAVA paper evaluates its image encoder with an L-BFGS-based logistic regression classifier trained as a classifier head on top of the image encoder's embeddings (a linear probe). I implemented the same setup, following the source referenced in the paper; the results are below.

In [12]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Linear probe on the frozen FLAVA embeddings.
# `.cpu()` lets the conversion work even if the embeddings were computed on a GPU,
# because `.numpy()` raises on CUDA tensors.
features = image_features.detach().cpu().numpy()

# Guard against hidden state: `image_features` is reassigned by the CIFAR-10 section further down,
# so re-running this cell afterwards would pair CIFAR embeddings with MNIST labels.
assert len(features) == len(train_small_label), (
    f"{len(features)} embeddings vs {len(train_small_label)} MNIST labels - re-run the MNIST feature cell"
)

# 80/20 train/test split with a fixed seed for reproducibility
X_train, X_test, y_train, y_test = train_test_split(features, train_small_label, test_size=0.2, random_state=42)

# Logistic regression fitted with the L-BFGS solver
logistic_model = LogisticRegression(solver='lbfgs', max_iter=1000)
logistic_model.fit(X_train, y_train)

# Evaluate on the held-out split
y_pred = logistic_model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)

Accuracy: 0.9258333333333333


CIFAR-10 with FLAVA — repeating the same procedure used above for MNIST, this time on the CIFAR-10 dataset

In [20]:
# Inspect one CIFAR-10 training image. Index the row first so that only this image is decoded;
# `cifar10["train"]["img"][0]` would decode the entire image column before taking element 0.
image_cifar = cifar10["train"][0]["img"]
# convert("RGB") guarantees a 3-channel image regardless of the stored mode.
tt = image_cifar.convert("RGB")
print(np.shape(tt))

(32, 32, 3)


In [27]:
# Take the first N_TRAIN_SMALL_CIFAR rows of the CIFAR-10 training split.
# Slicing rows decodes only those images, whereas `cifar10["train"]["img"][0:N]` decoded
# the entire image column before slicing.
N_TRAIN_SMALL_CIFAR = 1000
train_subset_cifar = cifar10["train"][:N_TRAIN_SMALL_CIFAR]  # dict: column name -> list of N values
train_small_cifar = train_subset_cifar["img"]
train_small_label_cifar = train_subset_cifar["label"]

In [28]:
# Not populated anywhere in this notebook; kept so that the global still exists.
feature_list_cifar = []

# Make sure every CIFAR-10 PIL image is in RGB mode before handing it to FLAVA.
# A comprehension avoids leaking the loop variables into the notebook namespace.
# Labels are not needed here; we only check that they line up.
assert len(train_small_cifar) == len(train_small_label_cifar), "CIFAR-10 images and labels are misaligned"
rgb_images_cifar = [sample.convert('RGB') for sample in train_small_cifar]

In [29]:
# Embed the CIFAR-10 subset with FLAVA's image encoder, keeping position 0 ([CLS]) of the
# projected sequence per image. Mini-batches keep memory bounded (see the MNIST feature cell).
# NOTE: this overwrites the MNIST `image_features`; the next cell reads this name.
EMBED_BATCH_SIZE = 64
flava_device = next(flava.parameters()).device  # run on whatever device the model is on

feature_batches_cifar = []
with torch.no_grad():
    for start in range(0, len(rgb_images_cifar), EMBED_BATCH_SIZE):
        batch_inputs = fe(rgb_images_cifar[start:start + EMBED_BATCH_SIZE], return_tensors="pt").to(flava_device)
        batch_features = flava.get_image_features(**batch_inputs)[:, 0, :]
        # Collect on the CPU so that the downstream `.numpy()` conversion works on any device.
        feature_batches_cifar.append(batch_features.cpu())

# (N, embedding_dim) tensor of CIFAR-10 embeddings
image_features = torch.cat(feature_batches_cifar, dim=0)

In [30]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Linear probe on the frozen FLAVA CIFAR-10 embeddings.
# `.cpu()` lets the conversion work even if the embeddings live on a GPU.
features_cifar = image_features.detach().cpu().numpy()

# Guard against hidden state: `image_features` is shared with the MNIST section,
# so make sure these embeddings really belong to the CIFAR-10 subset.
assert len(features_cifar) == len(train_small_label_cifar), (
    f"{len(features_cifar)} embeddings vs {len(train_small_label_cifar)} CIFAR-10 labels - re-run the CIFAR-10 feature cell"
)

# 80/20 train/test split with a fixed seed for reproducibility
X_train, X_test, y_train, y_test = train_test_split(features_cifar, train_small_label_cifar, test_size=0.2, random_state=42)

# Logistic regression fitted with the L-BFGS solver
logistic_model = LogisticRegression(solver='lbfgs', max_iter=1000)
logistic_model.fit(X_train, y_train)

# Evaluate on the held-out split
y_pred = logistic_model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)

Accuracy: 0.95
