In [None]:
%matplotlib inline

import wids
import matplotlib.pyplot as plt
import torch.utils.data
import torch.nn
from random import randrange
import os
# Logging switches exported before any shard is opened: WDS_VERBOSE_CACHE is
# set to "1" and GOPEN_VERBOSE to "0".
os.environ.update(
	{
		"WDS_VERBOSE_CACHE": "1",
		"GOPEN_VERBOSE": "0",
	}
)

# Training data preparation

In [None]:
# JSON shard index for the fake-imagenet training set.
train_url = "https://storage.googleapis.com/webdataset/fake-imagenet/imagenet-train.json"
dataset = wids.ShardListDataset(train_url)

# Fetch one sample by index and inspect it: its keys, its ".txt" entry, and
# its ".jpg" entry rendered as an image.
sample = dataset[1900]
print(sample.keys(), sample[".txt"], sep="\n")
plt.imshow(sample[".jpg"])

In [None]:
from datasets import load_dataset

ds = load_dataset("laion/laion-art")

# Export only the 'train' split, written to a Parquet file in the working directory.
train_split = ds["train"]
train_split.to_parquet("train_dataset.parquet")

In [None]:
from webdataset.tariterators import (
	base_plus_ext,
	tar_file_expander,
	url_opener,
	valid_sample,
)
import logging

def log_and_continue(exn):
	"""Exception handler that logs a warning for `exn` and asks the caller to keep going.

	:param exn: the exception being handled
	:return: always True
	"""
	message = f"Handling webdataset error ({repr(exn)}). Ignoring."
	logging.warning(message)
	return True

def group_by_keys_nothrow(
	data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None
):
	"""Generate grouped samples from a stream of file entries.

	Consecutive entries whose names share a prefix are collected into one dict
	keyed by suffix, together with "__key__" (the prefix) and "__url__" (copied
	from the entry that started the sample).

	:param data: iterable of dicts providing "fname", "data" and "__url__"
	:param keys: function that splits a file name into (prefix, suffix) (Default value = base_plus_ext)
	:param lcase: convert suffixes to lower case (Default value = True)
	:param suffixes: when not None, only suffixes contained in it are stored (Default value = None)
	:param handler: accepted but not used inside this function (Default value = None)
	"""
	pending = None
	for entry in data:
		assert isinstance(entry, dict)

		fname = entry["fname"]
		value = entry["data"]

		prefix, suffix = keys(fname)
		if prefix is None:
			# Entries whose name yields no prefix are skipped.
			continue
		suffix = suffix.lower() if lcase else suffix

		# FIXME the webdataset version throws when the suffix is already present in the
		#  sample being built. Here a repeated suffix starts a new sample instead, because
		#  in the current LAION400m dataset a tar can end with the same prefix the next one
		#  begins with (rare, but prefixes aren't unique across tar files in that dataset).
		starts_new_sample = (
			pending is None
			or prefix != pending["__key__"]
			or suffix in pending
		)
		if starts_new_sample:
			# Emit the sample collected so far (if valid_sample accepts it) before resetting.
			if valid_sample(pending):
				yield pending
			pending = {"__key__": prefix, "__url__": entry["__url__"]}

		if suffixes is None or suffix in suffixes:
			pending[suffix] = value

	# Emit whatever is still being collected once the input is exhausted.
	if valid_sample(pending):
		yield pending

# def tarfile_to_samples_nothrow(src, handler=log_and_continue):
#     # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
#     streams = url_opener(src, handler=handler)
#     files = tar_file_expander(streams, handler=handler)
#     # print("====> \n\n\n Files: ", files)
#     samples = group_by_keys_nothrow(files, handler=handler)
#     return samples



tar_src = "/home/azureuser/maund/open_flamingo/modifications/VLM_ADNI_DATA/replicate_mmc4/000000001.tar"

# url_opener iterates over sample dicts that carry a "url" key, so the single
# shard path is wrapped in a one-element list of dicts rather than passed as a
# bare string. The result is materialized into a list because url_opener is a
# one-shot generator: printing it directly would leave nothing for
# tar_file_expander to read.
streams = list(url_opener([dict(url=tar_src)], handler=log_and_continue))

for stream in streams:
	print(f"Stream: {stream}")

# Materialized for the same reason: the debug loop below must not drain the
# entries that group_by_keys_nothrow needs afterwards.
files = list(tar_file_expander(streams, handler=log_and_continue))

for file in files:
	print(f"File: {file['fname']}, Size: {len(file['data'])} bytes")


samples = group_by_keys_nothrow(files, handler=log_and_continue)
samples = list(samples)
print(samples)
# 


# Testing data preparation

In [10]:
# TextVQA annotation/question files for the train and val splits.
train_anno_vqa, train_vqa, val_anno_vqa, val_vqa = (
	"/home/anhnv16/maund/open-flamingo-3D/open_flamingo/eval/data/textvqa/train_annotations_vqa_format.json",
	"/home/anhnv16/maund/open-flamingo-3D/open_flamingo/eval/data/textvqa/train_questions_vqa_format.json",
	"/home/anhnv16/maund/open-flamingo-3D/open_flamingo/eval/data/textvqa/val_annotations_vqa_format.json",
	"/home/anhnv16/maund/open-flamingo-3D/open_flamingo/eval/data/textvqa/val_questions_vqa_format.json",
)

In [None]:
def reduce_size_to_n(file_name, n):
	"""Write a copy of a JSON file with its entry lists cut to the first n items.

	The "annotations" and "questions" lists, when present, are truncated to
	their first n items; every other key is copied unchanged. The top-level
	keys of the input are printed. The result is saved next to the input as
	"<stem>_<n><ext>".

	:param file_name: path of the JSON file to read
	:param n: number of leading entries to keep in each list
	"""
	import json
	import os

	with open(file_name, "r", encoding="utf-8") as file:
		data = json.load(file)

	print(data.keys())
	for key in ("annotations", "questions"):
		if key in data:
			data[key] = data[key][:n]

	# Separate the extension only. Splitting the whole path on "." misplaces the
	# suffix when a directory or stem contains a dot, and raises IndexError
	# when the name has no dot at all.
	stem, ext = os.path.splitext(file_name)
	new_file_name = f"{stem}_{n}{ext}"

	with open(new_file_name, "w", encoding="utf-8") as file:
		json.dump(data, file, indent=4, ensure_ascii=False)
	
		
# Produce a 50-entry copy of each annotation/question file.
for vqa_path in (train_anno_vqa, train_vqa, val_anno_vqa, val_vqa):
	reduce_size_to_n(vqa_path, 50)

dict_keys(['annotations', 'info', 'task_type', 'license', 'data_subtype'])
dict_keys(['questions', 'info', 'task_type', 'data_type', 'license', 'data_subtype'])
dict_keys(['annotations', 'info', 'task_type', 'license', 'data_subtype'])
dict_keys(['questions', 'info', 'task_type', 'data_type', 'license', 'data_subtype'])


# Extended testing data preparation

In [1]:
# TextVQA annotation/question files for the train and val splits.
train_anno_vqa, train_vqa, val_anno_vqa, val_vqa = (
	"/home/anhnv16/maund/open-flamingo-3D/open_flamingo/eval/data/textvqa/train_annotations_vqa_format.json",
	"/home/anhnv16/maund/open-flamingo-3D/open_flamingo/eval/data/textvqa/train_questions_vqa_format.json",
	"/home/anhnv16/maund/open-flamingo-3D/open_flamingo/eval/data/textvqa/val_annotations_vqa_format.json",
	"/home/anhnv16/maund/open-flamingo-3D/open_flamingo/eval/data/textvqa/val_questions_vqa_format.json",
)

In [6]:
def add_images(file_name, n):
	"""Write a copy of a JSON file in which every entry gains an "image_ids" list.

	For each item under "annotations" and "questions" (when present),
	"image_ids" is set to the item's "image_id" repeated n times. The result
	is saved next to the input as "<stem>_<n>_extended<ext>".

	:param file_name: path of the JSON file to read
	:param n: number of times each "image_id" is repeated
	:raises KeyError: if an item has no "image_id"
	"""
	import json
	import os

	with open(file_name, "r", encoding="utf-8") as file:
		data = json.load(file)

	for key in ("annotations", "questions"):
		if key in data:
			for item in data[key]:
				item["image_ids"] = [item["image_id"]] * n

	# Separate the extension only. Splitting the whole path on "." misplaces the
	# suffix when a directory or stem contains a dot, and raises IndexError
	# when the name has no dot at all.
	stem, ext = os.path.splitext(file_name)
	new_file_name = f"{stem}_{n}_extended{ext}"

	with open(new_file_name, "w", encoding="utf-8") as file:
		json.dump(data, file, indent=4, ensure_ascii=False)
  

# Produce a copy of each annotation/question file with every image id repeated 3 times.
for vqa_path in (train_anno_vqa, train_vqa, val_anno_vqa, val_vqa):
	add_images(vqa_path, 3)