In [5]:
import numpy as np
import os 
import pandas as pd
import cv2
import torch
import matplotlib.pyplot as plt
# from ipywidgets import interact
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torchvision
from torch import nn
import torchsummary
from torch.utils.data import DataLoader
from collections import defaultdict
from torchvision.utils import make_grid

In [6]:
# Run on the first CUDA GPU when PyTorch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

## Datasets

In [41]:
class Dataset():
    """Paired image / description / label samples stored on disk per phase.

    Expected layout (paths are built with ``os.path.join``, so ``root`` may or
    may not end with a path separator)::

        root/image/<phase>/...
        root/description/<phase>/...
        root/label/<phase>/...

    Each directory listing is sorted independently and samples are paired by
    position, so the three folders must hold the same number of files and
    their sorted orders must line up.

    Args:
        root: Dataset root directory.
        phase: Split sub-folder name (e.g. ``"train"``).
        transformer: Optional callable invoked as ``transformer(image=img)``.
            Its return value (e.g. the dict produced by an albumentations
            ``Compose``) replaces the image unchanged.
    """

    def __init__(self, root, phase, transformer=None):
        self.root = root
        self.phase = phase
        self.transformer = transformer
        # Listings are paired purely by sorted index.
        # NOTE(review): assumes matching file counts/names across folders — confirm.
        self.image_list = sorted(os.listdir(self._dir("image")))
        self.des_list = sorted(os.listdir(self._dir("description")))
        self.label_list = sorted(os.listdir(self._dir("label")))

    def __getitem__(self, index):
        """Return ``(img, des, label)`` for the sample at ``index``."""
        return self.get_data(index)

    def __len__(self):
        """Number of samples, taken from the image folder listing."""
        return len(self.image_list)

    def _dir(self, kind):
        """Directory holding the ``kind`` files for the current phase."""
        return os.path.join(self.root, kind, self.phase)

    @staticmethod
    def _read_text(path):
        """Read a whole text file as UTF-8 and close the handle afterwards."""
        # Explicit encoding so results don't depend on the machine's locale.
        with open(path, "r", encoding="utf-8") as fh:
            return fh.read()

    def get_data(self, index):
        """Load one sample from disk.

        Returns:
            ``(img, des, label)``, where:

            - ``label`` is the raw text of the label file.
            - ``des`` is the description text split on single spaces.
            - ``img`` is the RGB image array, or whatever ``transformer``
              returns when one is set.

        Raises:
            IndexError: if ``index`` is out of range for any of the listings.
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        # label
        label = self._read_text(
            os.path.join(self._dir("label"), self.label_list[index]))

        # description
        # NOTE(review): split(" ") keeps a trailing newline (if the file ends
        # with one) on the last token and yields "" for repeated spaces;
        # str.split() may be what is intended.
        des = self._read_text(
            os.path.join(self._dir("description"), self.des_list[index])).split(" ")

        # image
        img_path = os.path.join(self._dir("image"), self.image_list[index])
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"could not read image: {img_path}")
        # OpenCV loads BGR; convert to RGB.
        img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transformer is not None:
            img = self.transformer(image=img)

        return img, des, label
    

In [50]:
# Side length, in pixels, of the square every image is resized to.
IMAGE_SIZE = 448

# Preprocessing pipeline:
# 1. resize to IMAGE_SIZE x IMAGE_SIZE;
# 2. normalise each channel with the given mean/std (the commonly used
#    ImageNet statistics);
# 3. convert to a channel-first torch tensor.
transformer = A.Compose(
    [
        A.Resize(width=IMAGE_SIZE, height=IMAGE_SIZE),
        A.Normalize(std=(0.229, 0.224, 0.225), mean=(0.485, 0.456, 0.406)),
        ToTensorV2(),
    ]
)

In [51]:
# Dataset root directory. Set the DATA_ROOT environment variable to point the
# notebook at data stored elsewhere; otherwise the original location is used.
# os.path.join(..., "") guarantees a trailing separator, which the Dataset's
# path building may rely on.
root = os.path.join(os.environ.get("DATA_ROOT", "/home/host_data/nickData/"), "")
train_dataset = Dataset(root=root, phase="train", transformer=transformer)

In [52]:
# Fetch the first training sample to sanity-check the loading pipeline.
(img, des, label) = train_dataset[0]

In [54]:
# With a transformer set, `img` is the dict returned by the pipeline; the tensor is under "image".
img["image"].shape

torch.Size([3, 448, 448])

## Models