In [1]:
import json
import os

from PIL import Image
from torch.utils.data import Dataset
from src.preprocess import image_transform
from collections import defaultdict
import torch
import numpy as np
import random
from tqdm import tqdm
from torch.utils.data import Sampler, DataLoader
from src.model import build_model
from src.tokenizer import FoodTokenizer

# Load the baseline model configuration; only the JSON parsing needs the open file handle.
with open("src/model_configs/baseline.json") as config_file:
    configs = json.load(config_file)

# Split out the two sub-configs that the rest of the notebook refers to by name.
text_cfg = configs["text_cfg"]
vision_cfg = configs["vision_cfg"]

# Image pipeline built in training mode (is_train=True) at the configured image size.
preprocess = image_transform(
    vision_cfg["image_size"],
    is_train=True,
)


class FoodImageDataset(Dataset):
    """Map-style dataset yielding ``(label_text, transformed_image)`` pairs.

    Annotations are read from JSON files under ``data/``: ``labels.json`` must
    contain a ``"categories"`` list of ``{"id", "label"}`` records and the split
    file must contain an ``"images"`` list of records carrying ``"category_id"``
    and ``"file_name"``.

    Args:
        transforms: callable applied to each opened ``PIL.Image`` in
            ``__getitem__`` (and to every image in ``transform_func``).
        mode: ``"train"`` selects the train split; any other value selects the
            test split.
    """

    def __init__(self,transforms, mode="train"):
        self.dataset_path = "data"
        self.dataset_mode = "train" if mode == "train" else "test"
        self.labels_info_file_name = "labels.json"
        self.train_info_file_name = "train/aihub:1.0_43_0.3_train_crop.json"
        # NOTE(review): the test annotation file name also ends in "_train_crop" -- confirm
        # this is the real file name and not a copy/paste slip.
        self.test_info_file_name = "test/aihub:1.0_43_0.3_train_crop.json"
        self.labels_file_path = os.path.join(self.dataset_path, self.labels_info_file_name)
        self.train_file_path = os.path.join(self.dataset_path, self.train_info_file_name)
        self.test_file_path = os.path.join(self.dataset_path, self.test_info_file_name)

        self.label_data = None
        self.train_data = None
        self.id_to_text_dict = None
        self.text_to_id_dict = None

        # Choose the annotation file from dataset_mode so that every non-"train" mode
        # consistently uses the test split (files AND image directory), instead of
        # leaving self.labels / self.data unset and crashing a few lines later.
        if self.dataset_mode == "train":
            data_file_path = self.train_file_path
        else:
            data_file_path = self.test_file_path
        self.labels, self.data = self.get_dataset(self.labels_file_path, data_file_path)

        self.id_to_text_dict = self.get_id_to_text(self.labels)
        self.text_to_id_dict = self.get_text_to_id(self.labels)

        self.transforms = transforms

    def __len__(self):
        """Number of image records in the selected split."""
        return len(self.data)

    def get_dataset(self, labels_file_path, data_file_path):
        """Read both annotation files.

        Returns:
            ``(labels, data)``: the ``"categories"`` list of the labels file and
            the ``"images"`` list of the split file.
        """
        with open(labels_file_path, "r") as file:
            labels = json.load(file)["categories"]

        with open(data_file_path, "r") as file:
            data = json.load(file)["images"]

        return labels, data

    def get_id_to_text(self, label_data):
        """Build a ``{category id: label text}`` lookup from the category records."""
        return {item["id"]: item["label"] for item in label_data}

    def get_text_to_id(self, label_data):
        """Build a ``{label text: category id}`` lookup from the category records."""
        return {item["label"]: item["id"] for item in label_data}

    def transform_func(self, examples):
        """Apply the dataset transform to every image of a batch dict, in place.

        Args:
            examples: mapping with an ``"image"`` key holding an iterable of images.

        Returns:
            The same mapping, with ``examples["image"]`` replaced by the list of
            transformed images.
        """
        # The transform is stored as self.transforms; the previous self.preprocess
        # attribute never existed on the instance and raised AttributeError.
        examples["image"] = [self.transforms(image) for image in examples["image"]]
        return examples

    def __getitem__(self, idx):
        """Return ``(label_text, transformed_image)`` for record ``idx``."""
        record = self.data[idx]
        text = self.id_to_text_dict[record["category_id"]]
        # Only the base name of the annotated path is used; images are looked up
        # directly under data/<split>/.
        file_name = os.path.split(record["file_name"])[-1]
        file_path = os.path.join(self.dataset_path, self.dataset_mode, file_name)
        image = Image.open(file_path)
        image = self.transforms(image)
        return text, image

In [2]:
# Scratch prototype of the contrastive-sampler set-up: groups dataset indices by
# class and by coarse category, then (optionally) oversamples and shuffles them.
train_dataset = FoodImageDataset(preprocess, mode="train")
dset = train_dataset
oversample = False
seed = 42
shuffle = True
epoch = 0

with open("./food_id_to_category_id.json") as f:
    food_to_category = json.load(f)

# JSON object keys are strings; normalise keys and values to int.
food_to_category = {int(k): int(v) for k, v in food_to_category.items()}

# cls: class id -> dataset indices belonging to that class.
cls = {}
for ind, data in enumerate(dset.data):
    cls.setdefault(data["category_id"], []).append(ind)

# cls_class: coarse category id -> dataset indices of every class mapped to it.
cls_class = defaultdict(list)
for label, samples in cls.items():
    cls_class[food_to_category[label]].extend(samples)

# One read cursor per class, addressed by the class's position in cls_indicies.
cls_inds = [0 for _ in range(len(cls))]
max_n_sample = max(len(samples) for samples in cls.values())
if oversample:
    for label, samples in cls.items():
        # Repeat the class's own samples cyclically until it reaches max_n_sample.
        # (Slicing samples[:pad_size] could at most double the class, so small
        # classes were never padded up to the maximum.)
        originals = list(samples)
        pad_size = max_n_sample - len(originals)
        samples.extend(originals[i % len(originals)] for i in range(pad_size))

cls_indicies = list(cls.keys())
cls_matcher = {idx: idx for idx in range(len(cls_inds))}

if shuffle:
    g = torch.Generator()
    g.manual_seed(seed)

    # Shuffle the sample order inside every class.
    for label, indicies in cls.items():
        inds = np.array(indicies)
        cls[label] = inds[torch.randperm(len(indicies), generator=g).tolist()].tolist()

    if epoch != 0:
        # Permute the class visiting order. A separate name is used so the per-class
        # cursors in cls_inds are not overwritten by the permutation.
        class_order = torch.randperm(len(cls_inds), generator=g).tolist()
        cls_matcher = {idx: class_order[idx] for idx in range(len(cls_inds))}

In [3]:
# Scratch prototype of the sampler loop: emits (anchor, positive, negative) index
# triplets, visiting the classes round-robin.
indicies = []
ind_step = 0
for _ in range(len(dset)):
    # Work with the class's POSITION (slot) for cls_matcher / cls_inds and only turn
    # it into the class id for the dict lookups; using the class id as a position
    # only worked when the ids happened to be exactly 0..n-1.
    slot = cls_matcher[ind_step % len(cls_indicies)]
    cls_ind = cls_indicies[slot]
    smp_ind = cls_inds[slot] % len(cls[cls_ind])

    index = cls[cls_ind][smp_ind]

    ind_step += 1
    cls_inds[slot] += 1
    indicies.append(index)

    # positive: a different sample of the anchor's class
    # NOTE(review): never terminates if the class has a single distinct sample.
    cls_index = random.choice(cls[cls_ind])
    while index == cls_index:
        cls_index = random.choice(cls[cls_ind])
    indicies.append(cls_index)

    # hard negative: same coarse category, but a different class than the anchor
    # (the previous check only rejected the anchor index itself, so the "negative"
    # could belong to the anchor's own class).
    # NOTE(review): never terminates if the category contains only this class.
    cls_idx = random.choice(cls_class[food_to_category[cls_ind]])
    while dset.data[cls_idx]["category_id"] == cls_ind:
        cls_idx = random.choice(cls_class[food_to_category[cls_ind]])
    indicies.append(cls_idx)

In [4]:
class ContrastiveSampler(Sampler):
    """Sampler yielding dataset indices as consecutive (anchor, positive, negative) triplets.

    Classes are visited round-robin. For every anchor:

    * the positive is another sample of the anchor's class (the anchor itself
      when the class has a single sample);
    * the negative is a sample of a *different* class, taken from the anchor's
      coarse category when that category holds more than one class, otherwise
      from the whole dataset.

    Args:
        dset: dataset exposing ``data`` (a sequence of records with a
            ``"category_id"`` key) and ``__len__``.
        shuffle: shuffle the samples inside each class on every ``__iter__`` and,
            from the second epoch on, the class visiting order.
        seed: seed of the ``torch.Generator`` used for those shuffles.
        oversample: pad every class (by cyclic repetition) up to the size of the
            largest class.
        food_to_category: optional ``{class id: coarse category id}`` mapping.
            When omitted it is loaded from ``./food_id_to_category_id.json``.
            Keys and values are converted to ``int``.

    Note:
        Positive/negative draws use the global ``random`` module, so they are not
        controlled by ``seed``.
    """

    def __init__(self, dset, shuffle: bool = True, seed: int = 42, oversample=False,
                 food_to_category=None):
        self.dset = dset
        self.shuffle = shuffle
        self.seed = seed
        self.ind_cls = 0
        self.epoch = 0

        # Resolve the class -> coarse-category mapping BEFORE it is needed, and from
        # this object's own input rather than from a notebook-level global.
        if food_to_category is None:
            with open("./food_id_to_category_id.json") as f:
                food_to_category = json.load(f)
        self.food_to_category = {int(k): int(v) for k, v in food_to_category.items()}

        # self.cls: class id -> dataset indices of that class.
        self.cls = {}
        for ind, data in enumerate(self.dset.data):
            self.cls.setdefault(data["category_id"], []).append(ind)

        # self.cls_class: coarse category id -> dataset indices of all its classes.
        self.cls_class = defaultdict(list)
        # Number of distinct classes per coarse category (decides the negative pool).
        self._classes_per_category = defaultdict(int)
        for label, samples in self.cls.items():
            category = self.food_to_category[label]
            self.cls_class[category].extend(samples)
            self._classes_per_category[category] += 1

        # Distinct sample count per class, recorded before oversampling adds duplicates.
        self._n_distinct = {label: len(samples) for label, samples in self.cls.items()}

        # One read cursor per class, addressed by the class's position in cls_indicies.
        self.cls_inds = [0 for _ in range(len(self.cls))]
        self.max_n_sample = max((len(samples) for samples in self.cls.values()), default=0)
        if oversample:
            for samples in self.cls.values():
                # Cyclic repetition; slicing samples[:pad_size] could at most double
                # a class and never reached max_n_sample for small classes.
                originals = list(samples)
                pad_size = self.max_n_sample - len(originals)
                samples.extend(originals[i % len(originals)] for i in range(pad_size))

        self.cls_indicies = list(self.cls.keys())
        # position -> position; identity until the class order is shuffled.
        self.cls_matcher = {idx: idx for idx in range(len(self.cls_inds))}

    def _sample_positive(self, cls_ind, index):
        """Pick a sample of class ``cls_ind`` other than ``index`` (or ``index`` if it is alone)."""
        if self._n_distinct[cls_ind] < 2:
            # No other sample exists; rejection sampling would never terminate.
            return index
        samples = self.cls[cls_ind]
        candidate = random.choice(samples)
        while candidate == index:
            candidate = random.choice(samples)
        return candidate

    def _sample_negative(self, cls_ind):
        """Pick a sample whose class differs from ``cls_ind``, preferring the same coarse category.

        Raises:
            ValueError: if the dataset contains a single class.
        """
        category = self.food_to_category[cls_ind]
        if self._classes_per_category[category] > 1:
            pool = self.cls_class[category]
        elif len(self.cls_indicies) > 1:
            # The category holds only this class: fall back to any other class.
            pool = range(len(self.dset.data))
        else:
            raise ValueError("ContrastiveSampler needs at least two classes to draw negatives")
        candidate = random.choice(pool)
        while self.dset.data[candidate]["category_id"] == cls_ind:
            candidate = random.choice(pool)
        return candidate

    def __iter__(self):
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.seed)

            # Shuffle the sample order inside every class.
            for label, samples in self.cls.items():
                order = torch.randperm(len(samples), generator=g).tolist()
                self.cls[label] = [samples[i] for i in order]

            if self.epoch != 0:
                # From the second epoch on, also permute the class visiting order.
                class_order = torch.randperm(len(self.cls_inds), generator=g).tolist()
                self.cls_matcher = dict(enumerate(class_order))
            self.epoch += 1

        indicies = []
        n_classes = len(self.cls_indicies)
        self.ind_step = 0
        for _ in range(len(self.dset)):
            # slot is the class's position; cls_ind is the actual class id. Keeping
            # the two apart makes the sampler work for ids that are not 0..n-1.
            slot = self.cls_matcher[self.ind_step % n_classes]
            cls_ind = self.cls_indicies[slot]
            samples = self.cls[cls_ind]
            index = samples[self.cls_inds[slot] % len(samples)]

            self.ind_step += 1
            self.cls_inds[slot] += 1

            # Emit exactly one triplet per anchor. (Extending with the accumulated
            # positive/negative lists on every step produced n*n + 2n indices.)
            indicies.append(index)
            indicies.append(self._sample_positive(cls_ind, index))
            indicies.append(self._sample_negative(cls_ind))

        assert len(indicies) == len(self.dset) * 3
        assert len(self.cls) == len(set([self.dset.data[idx]["category_id"] for idx in indicies]))
        return iter(indicies)

    def __len__(self):
        # Three indices (anchor, positive, negative) are yielded per dataset item, so
        # report that count; DataLoader derives its number of batches from it.
        return 3 * len(self.dset)

In [5]:
# Build the sampler with its defaults (shuffle=True, seed=42, oversample=False).
sampler = ContrastiveSampler(train_dataset)

In [6]:
# Batch size is written as 2 * 3 -- presumably two (anchor, positive, negative)
# triplets per batch; confirm against the sampler's output order.
# Ordering is delegated to the custom sampler, hence shuffle=False.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=2 * 3,
    shuffle=False,
    sampler=sampler,
)

In [7]:
# Scratch cell: builds the tokenizer and the model, then pulls a single batch from
# the data loader.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokens_path = "./src/model_configs/tokens_by_length.json"
tokenizer = FoodTokenizer(tokens_path, configs=configs)
outputs = []
model = build_model(vision_cfg, text_cfg)
model.to(device)
# NOTE(review): the loader above was built with batch_size = 2 * 3, not 32 -- confirm
# which value the chunking below is meant to use.
batch_size = 32
outputs_texts = []
outputs_images = []

for texts , images in train_dataloader :
    # NOTE(review): this unconditional `break` leaves the loop on the first batch, so
    # every statement below it in the loop body is unreachable. The loop's only
    # effect is to bind `texts` and `images` to the first batch.
    break
    for index in range(0,batch_size * 3, batch_size):
        # NOTE(review): `texts` / `images` are overwritten by their own slices here, so
        # later iterations would slice the already-sliced chunk rather than the batch.
        texts = texts[index : index + batch_size]
        images = images[index : index + batch_size]
        images = images.to(device, dtype=torch.float32)
        texts = tokenizer(texts).to(device)
        logits_per_image, logits_per_text = model(images, texts)

        outputs_texts.append(logits_per_text)
        outputs_images.append(logits_per_image)
        # Stops after the first chunk.
        break

    logits_per_texts = torch.cat(outputs_texts)
    logits_per_images = torch.cat(outputs_images)

In [58]:
# Rebuilds the model from the same configs, replacing the earlier `model` binding.
# Unlike the earlier cell, the new instance is not moved to `device` here.
model = build_model(vision_cfg, text_cfg)

In [60]:
# NOTE(review): `image` is not assigned in any cell visible in this export, so this
# cell depends on leftover kernel state and will fail after Restart & Run All.
image.size()

torch.Size([96, 3, 224, 224])

In [61]:
# NOTE(review): `batch` is not assigned in any cell visible in this export (leftover
# kernel state). The slice is a bare expression inside a loop, so it has no effect
# and is not auto-displayed; wrap it in display(...) or collect the slices if the
# chunks are meant to be inspected.
for index in range(0,96,32):
    batch[index : index + 32]

[('배추겉절이',
  '배추겉절이',
  '깍두기',
  '돼지갈비',
  '돼지갈비',
  '불고기',
  '잡탕밥',
  '잡탕밥',
  '열무비빔밥',
  '오징어젓갈',
  '오징어젓갈',
  '무장아찌',
  '곱창전골',
  '곱창전골',
  '달걀국',
  '내장탕',
  '내장탕',
  '소고기전골',
  '삼선우동',
  '삼선우동',
  '물만두',
  '생선가스',
  '생선가스',
  '새우튀김',
  '고구마맛탕',
  '고구마맛탕',
  '고추튀김',
  '탕수육',
  '탕수육',
  '닭강정',
  '수수부꾸미',
  '수수부꾸미',
  '찰떡',
  '물만두',
  '물만두',
  '김치라면',
  '콩조림',
  '콩조림',
  '두부고추장조림',
  '흑미밥',
  '흑미밥',
  '해물덮밥',
  '단무지무침',
  '단무지무침',
  '오이생채',
  '고추장아찌',
  '고추장아찌',
  '오이지',
  '가지나물',
  '가지나물',
  '골뱅이무침',
  '생연어',
  '생연어',
  '생선물회',
  '두부전',
  '두부전',
  '호박전',
  '치킨데리야끼',
  '치킨데리야끼',
  '양념왕갈비',
  '배추김치',
  '배추김치',
  '배추겉절이',
  '불고기덮밥',
  '불고기덮밥',
  '제육덮밥',
  '숙주나물',
  '숙주나물',
  '고구마줄기나물',
  '두부구이',
  '두부구이',
  '햄버거스테이크',
  '깍두기',
  '깍두기',
  '파김치',
  '소고기메추리알장조림',
  '소고기메추리알장조림',
  '알감자조림',
  '우거지해장국',
  '우거지해장국',
  '홍합미역국',
  '떡국',
  '떡국',
  '짬뽕라면',
  '보리밥',
  '보리밥',
  '육회비빔밥',
  '소고기국밥',
  '소고기국밥',
  '샐러드김밥',
  '라볶이',
  '라볶이',
  '마파두부',
  '오리탕',
  '오리탕',
  '배추된장국'),
 tensor([[[[-1.7923, -