In [1]:
import os
import numpy as np
import pandas as pd
from mixgen import MixGen
from torchvision import transforms
import random
from tqdm import tqdm
import cv2
from PIL import Image
import json
import torch
os.chdir("../")
os.getcwd()

  from .autonotebook import tqdm as notebook_tqdm


'/workspace'

In [2]:
transform_after_mix = transforms.Compose([
                                            transforms.RandomResizedCrop(256,scale=(0.5, 1.0), interpolation=Image.BICUBIC),
                                            transforms.ToTensor(),
                                            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
                                        ])



In [3]:
class MixGen:
    def __init__(self,ann_file,image_root,transform,lam = 0.5):
        self.ann_file   = json.load(open(ann_file,'r'))
        self.image_root = image_root
        self.transform  = transform 
        self.lam  = lam
        
    def __call__(self,ann):
        ann1 = self.ann_file[ann]
        image1 = Image.open(os.path.join(self.image_root,ann1['file_name']))
        text1 = ann1['captions']
        
        ann2 = self.ann_file[np.random.choice(list(self.ann_file.keys()))]
        image2 = Image.open(os.path.join(self.image_root,ann2['file_name']))
        text2 = ann2['captions']
        
        image = self.lam * self.transform(image1) + (1-self.lam) * self.transform(image2)
        text = [text1_i + ' ' + text2_i for text1_i, text2_i in zip(text1, text2)]
        
        return image,text


In [4]:
Mixgen_ = MixGen(ann_file="data/COCO/Annotations/coco_img_info.json", image_root = "./data/COCO/Images", transform=transform_after_mix, lam=0.5)

In [5]:
total_imagenum = 10000
classic_dataset = []
new_dataset = {}
with open("data/COCO/Annotations/coco_img_info.json", "r") as f:
    cocodata = json.load(f)

In [6]:
if not os.path.exists(('data/COCO/pre_mixgen')):
    os.makedirs(('data/COCO/pre_mixgen'))
_= 0 
for img_id in tqdm(list(cocodata.keys())):
    try:
        img, txt = Mixgen_(img_id)
        img = torch.permute(img, (1,2,0)).detach().numpy()
        img = (img - np.min(img)) / (np.max(img) - np.min(img))
        img = img*255
        # save image to data/COCO/pre_mixgen
        cv2.imwrite(f"data/COCO/pre_mixgen/{_:06d}.jpg",  cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
        img_path = os.path.join("data/COCO/pre_mixgen", f"{_:06d}.jpg")
        classic_dataset.extend([{'caption': text, 'image': img_path, 'image_id': f"coco_{_:06d}"} for text in txt])
        new_dataset.setdefault(str(_),[].extend(txt))
        _+=1
    except:
        continue
    if _%100 == 0:
        with open(f"data/COCO/Annotations/pre_mixgen.json", "w") as f:
            json.dump(classic_dataset, f)
        with open(f"data/COCO/Annotations/pre_mixgen_new.json", "w") as f:
            json.dump(new_dataset, f)
    if _ > total_imagenum:
        break

  8%|▊         | 10195/123287 [07:41<1:25:19, 22.09it/s]


In [8]:
classic_dataset[0]

{'caption': 'A restaurant has modern wooden tables and chairs. A pizza that has two slices missing from it.',
 'image': 'data/COCO/pre_mixgen/000000.jpg',
 'image_id': 'coco_000000'}

In [11]:
img

array([[[224.7158  , 226.3084  , 222.0735  ],
        [135.68645 , 143.10316 , 166.4231  ],
        [132.3538  , 124.50435 , 146.01793 ],
        ...,
        [227.09627 , 231.69226 , 235.52235 ],
        [220.90706 , 221.9034  , 225.78354 ],
        [223.28754 , 224.35062 , 226.71103 ]],

       [[134.73428 , 145.06093 , 176.16191 ],
        [169.96513 , 158.27586 , 180.3357  ],
        [192.8176  , 159.25478 , 127.93153 ],
        ...,
        [225.19191 , 236.09723 , 237.84111 ],
        [221.85925 , 221.9034  , 224.85602 ],
        [226.14409 , 227.77672 , 229.49356 ]],

       [[174.24998 , 160.7231  , 176.62567 ],
        [218.52661 , 185.68465 , 147.87294 ],
        [196.15025 , 110.79995 , 120.975235],
        ...,
        [229.47676 , 235.60779 , 242.0149  ],
        [221.38316 , 223.86118 , 227.63853 ],
        [226.6202  , 230.22394 , 233.20358 ]],

       ...,

       [[122.83195 ,  83.39117 ,  88.04874 ],
        [124.26022 ,  81.922844,  91.75876 ],
        [127.59288 ,  