In [1]:
import sys
sys.path.append('/home/daniel/gitrepos/vision/references/detection')
import torch
from engine import train_one_epoch, evaluate
import utils
import transforms as T
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import torch.utils.data
from PIL import Image, ImageFile
import pandas as pd
from tqdm import tqdm

import collections
import os
import numpy as np
from sklearn.preprocessing import LabelEncoder
from torchvision import transforms

ImageFile.LOAD_TRUNCATED_IMAGES = True

In [2]:
# Location of the Open Images annotation CSVs on this machine.
csv_dir = "/hdd/open-images/csvs"
# Un-collapsed bounding-box table (later cells read a LabelName column from the same file).
bbox_csv_path = f"{csv_dir}/train-annotations-bbox.csv"
train_df = pd.read_csv(bbox_csv_path)

In [3]:
# label_encoder = LabelEncoder()
# train_df["LabelEncoded"] = label_encoder.fit_transform(train_df["LabelName"])
# print("Finished encoding labels")
# collapsed_df = train_df.groupby('ImageID', as_index=False).agg(lambda x: " ".join(x.astype(str)))

In [4]:
# print("Original length: ", len(train_df))
# print("Collapsed length: ", len(collapsed_df))
# collapsed_df.to_csv(f"{csv_dir}/train-annotations-bbox-collapsed.csv", index=False)

In [5]:
# One-off export of the fitted encoder (kept for reference; already run once).
# torch.save(label_encoder,"label_encoder.bin")
# Restore the previously saved encoder object from disk.
# NOTE(review): torch.load unpickles arbitrary Python objects, so only load a
# "label_encoder.bin" that you produced yourself. Presumably this file holds
# the sklearn LabelEncoder from the commented-out cell above -- confirm.
label_encoder = torch.load("label_encoder.bin")

In [16]:
def get_instance_segmentation_model(num_classes):
    """Build a COCO-pretrained Faster R-CNN (ResNet-50 FPN) with a fresh box head.

    Despite the name, the returned model is a box detector: only the
    box predictor of the ROI heads is replaced, no mask head is added.

    Args:
        num_classes: number of outputs of the new classification head.
            NOTE(review): torchvision detection models treat label 0 as
            background, so this count is expected to include it -- confirm
            that callers pass "real classes + 1".

    Returns:
        The torchvision detection model with its box predictor swapped for a
        new ``FastRCNNPredictor`` sized for ``num_classes``.
    """
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    roi_heads = detector.roi_heads
    # Width of the feature vector fed to the existing classification layer.
    feature_dim = roi_heads.box_predictor.cls_score.in_features
    # Replace the pretrained head with an untrained one of the requested size.
    roi_heads.box_predictor = FastRCNNPredictor(feature_dim, num_classes)
    return detector

class OpenDataset(torch.utils.data.Dataset):
    def __init__(self, image_dir, df_path, height, width, img_transforms=None):
        self.transforms = img_transforms
        self.image_dir = image_dir
        self.df = pd.read_csv(df_path)
        self.df = self.df[:1000]
        self.height = height
        self.width = width
        self.image_info = collections.defaultdict(dict)
        
        # Filling up image_info is left as an exercise to the reader
        
        counter = 0
        for index, row in tqdm(self.df.iterrows(), total=len(self.df)):
            image_id = row["ImageID"]
            image_path = os.path.join(self.image_dir, image_id)

            if os.path.exists(image_path + '.jpg'):
                self.image_info[counter]["image_id"] = image_id
                self.image_info[counter]["image_path"] = image_path
                self.image_info[counter]["XMin"] = row["XMin"]
                self.image_info[counter]["YMin"] = row["YMin"]
                self.image_info[counter]["XMax"] = row["XMax"]
                self.image_info[counter]["YMax"] = row["YMax"]
                self.image_info[counter]["labels"] = row["LabelEncoded"]
                counter += 1

    def __getitem__(self, idx):
        # load images ad masks
        img_path = self.image_info[idx]["image_path"] + ".jpg"
        img = Image.open(img_path).convert("RGB")
        img = img.resize((self.width, self.height), resample=Image.BILINEAR)
        
        # processing part and extraction of boxes is left as an exercise to the reader
        # get bounding box coordinates for each mask         
        num_objs = len(self.image_info[idx]["labels"].split())
        
        boxes = []
        xmins = self.image_info[idx]["XMin"].split()
        ymins = self.image_info[idx]["YMin"].split()
        xmaxs = self.image_info[idx]["XMax"].split()
        ymaxs = self.image_info[idx]["YMax"].split()
        
        assert len(xmins) == len(ymins) == len(xmaxs) == len(ymaxs) == num_objs
        
        for i in range(num_objs):
            xmin = float(xmins[i])
            xmax = float(xmaxs[i])
            ymin = float(ymins[i])
            ymax = float(ymaxs[i])
            boxes.append([xmin, ymin, xmax, ymax])
                                                
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor([int(x) for x in self.image_info[idx]["labels"].split()])

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
                                                              
        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = image_id
        target["area"] = area

        if self.transforms is not None:
            img, target = self.transforms(img, target)
        else:
            img = transforms.ToTensor()(img)

#         print(img)
#         print(target)
        return img, target

    def __len__(self):
        return len(self.image_info)
    

In [7]:
csv_dir = "/hdd/open-images/csvs"
# The un-collapsed annotation table is already in memory as `train_df` (loaded
# from this same path at the top of the notebook); alias it instead of parsing
# the CSV a second time and holding two copies of it.
train_csv = train_df
label_counts = train_csv["LabelName"].value_counts()
# One class per distinct LabelName.
# NOTE(review): torchvision detection heads reserve index 0 for background;
# this count does not include a background slot -- confirm this is intended.
num_classes = len(label_counts)
# num_greater_than_50 = len(label_counts[label_counts >= 50])
# print("Number of classes with at least 50 examples: ", num_greater_than_50)
# num_classes = num_greater_than_50
print(num_classes)

599


In [17]:
# Prefer the second GPU as before, but fall back to the first GPU or the CPU
# so the notebook still runs on machines without a "cuda:1" device.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

train_images_dir = "data/unzipped/all_train"
train_csv_file = "data/csvs/train-annotations-bbox-collapsed.csv"
# Images are resized to 128x128 (height, width); no extra augmentation.
dataset_train = OpenDataset(train_images_dir, train_csv_file, 128, 128, img_transforms=None)
print(len(dataset_train))

100%|██████████| 1000/1000 [00:00<00:00, 10385.05it/s]

1000





In [18]:
# Build the detector and move its weights to the training device
# (nn.Module.to moves the module in place and returns it).
model_ft = get_instance_segmentation_model(num_classes).to(device)

# utils.collate_fn comes from the torchvision detection reference scripts.
loader_options = dict(
    batch_size=4,
    shuffle=True,
    num_workers=8,
    collate_fn=utils.collate_fn,
)
data_loader = torch.utils.data.DataLoader(dataset_train, **loader_options)

In [19]:
# Optimise only the parameters that are not frozen.
params = list(filter(lambda p: p.requires_grad, model_ft.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)
# Multiply the learning rate by 0.1 every 5 scheduler steps (one step per epoch).
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

num_epochs = 8
for epoch in range(num_epochs):
    # train_one_epoch is the helper from the torchvision detection references.
    train_one_epoch(
        model_ft, optimizer, data_loader, device, epoch, print_freq=10
    )
    lr_scheduler.step()

# Persist only the weights (state dict), not the whole module object.
final_state = model_ft.state_dict()
torch.save(final_state, "model.bin")

tensor([[[0.3765, 0.3765, 0.3686,  ..., 0.0039, 0.0000, 0.0000],
         [0.3725, 0.3725, 0.3725,  ..., 0.0039, 0.0000, 0.0000],
         [0.3686, 0.3725, 0.3725,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0588, 0.1804, 0.2078,  ..., 0.1176, 0.1216, 0.1137],
         [0.2000, 0.2902, 0.3294,  ..., 0.1804, 0.1765, 0.1216],
         [0.2314, 0.2549, 0.3412,  ..., 0.2431, 0.1765, 0.1373]],

        [[0.3961, 0.3922, 0.3922,  ..., 0.0353, 0.0078, 0.0000],
         [0.3922, 0.3922, 0.3922,  ..., 0.0314, 0.0039, 0.0000],
         [0.3882, 0.3882, 0.3922,  ..., 0.0157, 0.0000, 0.0000],
         ...,
         [0.0706, 0.2706, 0.3608,  ..., 0.1882, 0.2000, 0.1804],
         [0.2588, 0.3961, 0.4118,  ..., 0.1922, 0.1725, 0.1176],
         [0.3255, 0.3608, 0.3922,  ..., 0.2314, 0.1569, 0.1098]],

        [[0.3569, 0.3529, 0.3569,  ..., 0.0039, 0.0000, 0.0000],
         [0.3529, 0.3529, 0.3529,  ..., 0.0039, 0.0000, 0.0000],
         [0.3490, 0.3451, 0.3451,  ..., 0.0000, 0.0000, 0.

        [0.5250, 0.4600, 0.6950, 0.8367]]), 'labels': tensor([ 67, 357, 394]), 'image_id': tensor([190]), 'area': tensor([0.2184, 0.1946, 0.0640])}
tensor([[[0.9294, 0.9294, 0.9294,  ..., 0.9137, 0.9176, 0.9255],
         [0.9255, 0.9255, 0.9255,  ..., 0.9176, 0.9216, 0.9255],
         [0.9216, 0.9216, 0.9216,  ..., 0.9255, 0.9255, 0.9216],
         ...,
         [0.3412, 0.3569, 0.3686,  ..., 0.4824, 0.4902, 0.4510],
         [0.3569, 0.3765, 0.3961,  ..., 0.4627, 0.4706, 0.4863],
         [0.3765, 0.4078, 0.4235,  ..., 0.4235, 0.4392, 0.4941]],

        [[0.9294, 0.9294, 0.9294,  ..., 0.9137, 0.9216, 0.9294],
         [0.9255, 0.9255, 0.9255,  ..., 0.9176, 0.9255, 0.9255],
         [0.9216, 0.9216, 0.9216,  ..., 0.9294, 0.9255, 0.9216],
         ...,
         [0.4471, 0.4784, 0.4431,  ..., 0.4275, 0.4824, 0.4667],
         [0.4941, 0.5020, 0.4863,  ..., 0.4941, 0.5020, 0.5216],
         [0.4980, 0.5255, 0.5373,  ..., 0.5529, 0.5569, 0.5922]],

        [[0.9216, 0.9255, 0.9216,  ..., 

        2.4885e-02])}
tensor([[[0.0824, 0.0706, 0.0549,  ..., 0.0980, 0.0510, 0.0275],
         [0.0902, 0.0667, 0.0706,  ..., 0.1176, 0.0941, 0.0471],
         [0.0902, 0.0784, 0.0784,  ..., 0.1059, 0.1059, 0.0745],
         ...,
         [0.2275, 0.2196, 0.2314,  ..., 0.2549, 0.2627, 0.2588],
         [0.1843, 0.1922, 0.2078,  ..., 0.3294, 0.2863, 0.2902],
         [0.1451, 0.1490, 0.1529,  ..., 0.2667, 0.2431, 0.2510]],

        [[0.1176, 0.1176, 0.1020,  ..., 0.1294, 0.0706, 0.0353],
         [0.1294, 0.1137, 0.1137,  ..., 0.1451, 0.1216, 0.0667],
         [0.1373, 0.1216, 0.1255,  ..., 0.1333, 0.1373, 0.0980],
         ...,
         [0.0431, 0.0392, 0.0392,  ..., 0.0627, 0.0627, 0.0627],
         [0.0392, 0.0392, 0.0431,  ..., 0.1647, 0.1137, 0.1098],
         [0.0353, 0.0392, 0.0471,  ..., 0.1412, 0.1216, 0.1294]],

        [[0.1922, 0.2353, 0.2000,  ..., 0.1373, 0.0706, 0.0353],
         [0.2000, 0.2000, 0.2196,  ..., 0.1608, 0.1294, 0.0627],
         [0.2353, 0.2039, 0.2431,  .

        0.0013])}
tensor([[[0.6078, 0.6078, 0.6078,  ..., 0.6275, 0.6235, 0.6196],
         [0.6118, 0.6118, 0.6118,  ..., 0.6235, 0.6235, 0.6275],
         [0.6157, 0.6157, 0.6157,  ..., 0.6235, 0.6235, 0.6275],
         ...,
         [0.1686, 0.1569, 0.1569,  ..., 0.2980, 0.2824, 0.1922],
         [0.2902, 0.2902, 0.2902,  ..., 0.2196, 0.2039, 0.1294],
         [0.3137, 0.3137, 0.3294,  ..., 0.4353, 0.4000, 0.2980]],

        [[0.8039, 0.8039, 0.8039,  ..., 0.8039, 0.8078, 0.8039],
         [0.8078, 0.8078, 0.8078,  ..., 0.8039, 0.8078, 0.8118],
         [0.8078, 0.8078, 0.8078,  ..., 0.8118, 0.8078, 0.8118],
         ...,
         [0.2431, 0.2314, 0.2235,  ..., 0.2549, 0.2471, 0.1647],
         [0.3843, 0.3882, 0.3765,  ..., 0.2275, 0.2078, 0.1255],
         [0.3882, 0.3922, 0.3961,  ..., 0.4078, 0.3843, 0.2902]],

        [[0.9961, 0.9961, 0.9961,  ..., 0.9922, 0.9804, 0.9882],
         [0.9961, 1.0000, 1.0000,  ..., 0.9961, 0.9882, 0.9882],
         [1.0000, 0.9961, 1.0000,  ..., 

         [0.8196, 0.8039, 0.8275,  ..., 0.5020, 0.6549, 0.9647]]])
{'boxes': tensor([[0.1094, 0.4625, 0.2422, 0.9953],
        [0.2281, 0.8000, 0.4906, 0.9906],
        [0.2578, 0.4625, 0.4094, 0.6984],
        [0.3625, 0.4750, 0.4922, 0.6562],
        [0.4516, 0.4969, 0.5938, 0.7000],
        [0.6531, 0.7703, 0.8375, 0.9953],
        [0.7297, 0.2766, 0.8234, 0.3891],
        [0.0109, 0.0000, 0.9937, 0.2313],
        [0.9328, 0.1672, 0.9953, 0.2438]]), 'labels': tensor([ 67,  67,  67,  67,  67,  67,  67, 463, 463]), 'image_id': tensor([858]), 'area': tensor([0.0708, 0.0500, 0.0358, 0.0235, 0.0289, 0.0415, 0.0105, 0.2273, 0.0048])}
tensor([[[0.6353, 0.6392, 0.6353,  ..., 0.4157, 0.4000, 0.3922],
         [0.5804, 0.6039, 0.6157,  ..., 0.4078, 0.4000, 0.3961],
         [0.5412, 0.5647, 0.5882,  ..., 0.4196, 0.3922, 0.4118],
         ...,
         [0.1569, 0.1490, 0.1647,  ..., 0.8824, 0.9686, 0.9882],
         [0.2627, 0.2118, 0.1843,  ..., 0.8510, 0.9216, 0.9137],
         [0.1569, 0.17

        7.9257e-03, 4.2389e-03])}
tensor([[[0.1608, 0.1725, 0.1882,  ..., 0.8745, 0.9725, 0.9961],
         [0.1608, 0.1765, 0.1922,  ..., 0.8784, 0.9686, 0.9922],
         [0.1686, 0.1765, 0.1961,  ..., 0.8902, 0.9804, 0.9922],
         ...,
         [0.0235, 0.0235, 0.0235,  ..., 0.4784, 0.4745, 0.4392],
         [0.0235, 0.0235, 0.0235,  ..., 0.4549, 0.4471, 0.4471],
         [0.0235, 0.0235, 0.0235,  ..., 0.4510, 0.4353, 0.4549]],

        [[0.1176, 0.1294, 0.1373,  ..., 0.4353, 0.5137, 0.6039],
         [0.1176, 0.1333, 0.1373,  ..., 0.4392, 0.5176, 0.6078],
         [0.1255, 0.1333, 0.1451,  ..., 0.4471, 0.5059, 0.5765],
         ...,
         [0.0235, 0.0235, 0.0235,  ..., 0.2314, 0.2314, 0.2157],
         [0.0235, 0.0235, 0.0235,  ..., 0.2275, 0.2078, 0.2235],
         [0.0235, 0.0235, 0.0235,  ..., 0.2314, 0.2118, 0.2196]],

        [[0.1020, 0.1137, 0.1255,  ..., 0.1686, 0.1843, 0.1922],
         [0.1020, 0.1176, 0.1255,  ..., 0.1647, 0.1843, 0.1922],
         [0.1098, 0.1176

        0.0116, 0.0140, 0.0071, 0.0018, 0.0135, 0.0022])}
tensor([[[0.4784, 0.5451, 0.6196,  ..., 1.0000, 1.0000, 1.0000],
         [0.5333, 0.5294, 0.5373,  ..., 1.0000, 1.0000, 1.0000],
         [0.5333, 0.3686, 0.3490,  ..., 1.0000, 1.0000, 1.0000],
         ...,
         [0.3451, 0.2863, 0.3216,  ..., 1.0000, 1.0000, 1.0000],
         [0.3765, 0.3255, 0.2980,  ..., 1.0000, 1.0000, 1.0000],
         [0.3294, 0.3176, 0.2902,  ..., 1.0000, 1.0000, 1.0000]],

        [[0.5765, 0.6549, 0.6863,  ..., 1.0000, 1.0000, 1.0000],
         [0.6353, 0.6353, 0.5804,  ..., 1.0000, 1.0000, 1.0000],
         [0.6510, 0.4471, 0.4039,  ..., 1.0000, 1.0000, 1.0000],
         ...,
         [0.3647, 0.2980, 0.3333,  ..., 1.0000, 1.0000, 1.0000],
         [0.3922, 0.3333, 0.2980,  ..., 1.0000, 1.0000, 1.0000],
         [0.3373, 0.3216, 0.2863,  ..., 1.0000, 1.0000, 1.0000]],

        [[0.2510, 0.2784, 0.3255,  ..., 1.0000, 1.0000, 1.0000],
         [0.2824, 0.2824, 0.2902,  ..., 1.0000, 1.0000, 1.0000],


         [0.4706, 0.4902, 0.5098,  ..., 0.5216, 0.5137, 0.5020]]])
{'boxes': tensor([[0.5150, 0.1380, 0.9500, 0.5520],
        [0.1870, 0.5340, 0.6950, 0.6330],
        [0.0490, 0.3990, 0.7820, 0.8900],
        [0.4210, 0.2420, 0.5030, 0.6320],
        [0.0400, 0.3910, 0.7790, 0.9000]]), 'labels': tensor([177, 191, 236, 238, 291]), 'image_id': tensor([256]), 'area': tensor([0.1801, 0.0503, 0.3599, 0.0320, 0.3762])}
tensor([[[0.8314, 0.8275, 0.8157,  ..., 0.5451, 0.5529, 0.5843],
         [0.8314, 0.8196, 0.8078,  ..., 0.6863, 0.6706, 0.6824],
         [0.7647, 0.7569, 0.7569,  ..., 0.7804, 0.7451, 0.7451],
         ...,
         [0.4078, 0.3882, 0.4000,  ..., 0.3294, 0.3176, 0.3059],
         [0.4039, 0.3882, 0.4078,  ..., 0.3216, 0.3020, 0.2863],
         [0.4118, 0.3961, 0.4039,  ..., 0.3255, 0.3098, 0.2863]],

        [[0.8078, 0.8078, 0.8039,  ..., 0.6157, 0.6118, 0.6431],
         [0.8157, 0.8118, 0.8078,  ..., 0.7176, 0.7020, 0.7098],
         [0.7922, 0.7882, 0.7922,  ..., 0.768

         [0.0275, 0.0275, 0.0275,  ..., 0.0235, 0.0235, 0.0235]]])
{'boxes': tensor([[0.4900, 0.0000, 0.9994, 0.9652],
        [0.4906, 0.0000, 0.9994, 0.9746],
        [0.7544, 0.1488, 0.8200, 0.2298]]), 'labels': tensor([388, 456, 456]), 'image_id': tensor([810]), 'area': tensor([0.4916, 0.4958, 0.0053])}
tensor([[[0.4784, 0.4824, 0.4784,  ..., 0.1725, 0.1490, 0.1333],
         [0.4745, 0.4824, 0.4863,  ..., 0.1647, 0.1569, 0.1451],
         [0.4784, 0.4706, 0.4863,  ..., 0.1529, 0.1569, 0.1373],
         ...,
         [0.6745, 0.6784, 0.6706,  ..., 0.8314, 0.8235, 0.8235],
         [0.6549, 0.6588, 0.6471,  ..., 0.8275, 0.8235, 0.8196],
         [0.6392, 0.6275, 0.6039,  ..., 0.8235, 0.8196, 0.8118]],

        [[0.3686, 0.3608, 0.3529,  ..., 0.2745, 0.2549, 0.2392],
         [0.3569, 0.3647, 0.3608,  ..., 0.2627, 0.2549, 0.2471],
         [0.3647, 0.3569, 0.3686,  ..., 0.2549, 0.2549, 0.2392],
         ...,
         [0.7922, 0.7922, 0.7843,  ..., 0.2314, 0.2275, 0.2314],
         [0

         [0.2510, 0.2549, 0.2627,  ..., 0.1882, 0.1608, 0.0863]]])
{'boxes': tensor([[0.4588, 0.9493, 0.4719, 0.9878],
        [0.4756, 0.9184, 0.4913, 0.9512],
        [0.4819, 0.8818, 0.4988, 0.9006],
        [0.5600, 0.7664, 0.5744, 0.8433],
        [0.5719, 0.6623, 0.5781, 0.6782],
        [0.5731, 0.4869, 0.5969, 0.5150],
        [0.5769, 0.7664, 0.5844, 0.7936],
        [0.5800, 0.6689, 0.5900, 0.6942],
        [0.5938, 0.6435, 0.6000, 0.6679],
        [0.6012, 0.4784, 0.6025, 0.4906],
        [0.6087, 0.6435, 0.6169, 0.6642],
        [0.6112, 0.6182, 0.6181, 0.6360],
        [0.6156, 0.5647, 0.6206, 0.5779],
        [0.6194, 0.6257, 0.6244, 0.6379],
        [0.6206, 0.6454, 0.6294, 0.6660],
        [0.6219, 0.5629, 0.6281, 0.5797],
        [0.0000, 0.6379, 0.1544, 0.8227],
        [0.0000, 0.3668, 0.1744, 0.6454],
        [0.0350, 0.8762, 0.1131, 0.9991],
        [0.1800, 0.3068, 0.4156, 0.4822],
        [0.4231, 0.3246, 0.5050, 0.4493],
        [0.4437, 0.4437, 0.4925, 0.5685],

        [0.8225, 0.7491, 0.8625, 0.8839]]), 'labels': tensor([ 81, 388, 397, 406, 406, 406, 406, 406, 406]), 'image_id': tensor([846]), 'area': tensor([0.0012, 0.6733, 0.5086, 0.0033, 0.0145, 0.0085, 0.0050, 0.0033, 0.0054])}
tensor([[[0.6824, 0.3059, 0.1765,  ..., 0.2118, 0.1647, 0.2902],
         [0.8667, 0.6863, 0.3725,  ..., 0.1882, 0.1451, 0.2863],
         [0.8863, 0.8235, 0.6824,  ..., 0.1804, 0.1294, 0.2667],
         ...,
         [0.2314, 0.2235, 0.2196,  ..., 0.2392, 0.2510, 0.2863],
         [0.2353, 0.2510, 0.2667,  ..., 0.2627, 0.2863, 0.2980],
         [0.2353, 0.2471, 0.2745,  ..., 0.3373, 0.3686, 0.2980]],

        [[0.6941, 0.3020, 0.1608,  ..., 0.2275, 0.1137, 0.0745],
         [0.9059, 0.7098, 0.3725,  ..., 0.1961, 0.1020, 0.0627],
         [0.9412, 0.8745, 0.6902,  ..., 0.1765, 0.0980, 0.0667],
         ...,
         [0.2471, 0.2588, 0.2706,  ..., 0.3373, 0.3412, 0.3412],
         [0.2078, 0.2353, 0.2627,  ..., 0.3569, 0.3569, 0.3216],
         [0.1569, 0.1765, 0.2

         [0.3216, 0.3255, 0.3490,  ..., 0.3647, 0.3608, 0.3490]]])
{'boxes': tensor([[0.0950, 0.6831, 0.3560, 0.9984],
        [0.4460, 0.6226, 0.6580, 0.9220],
        [0.4540, 0.4490, 0.5050, 0.5924],
        [0.6620, 0.5159, 0.7820, 0.7707],
        [0.7770, 0.4968, 0.8840, 0.6704],
        [0.0230, 0.1815, 0.0900, 0.5430],
        [0.0780, 0.0000, 0.4750, 0.9984],
        [0.4510, 0.0000, 0.7680, 0.9984],
        [0.6460, 0.0717, 0.8710, 0.9984],
        [0.7790, 0.1290, 0.9910, 0.9268],
        [0.8960, 0.2580, 0.9990, 0.7420],
        [0.0000, 0.6274, 0.0290, 0.6592],
        [0.0240, 0.5048, 0.0390, 0.5350],
        [0.0310, 0.4825, 0.0440, 0.5175],
        [0.0500, 0.4904, 0.0700, 0.5048],
        [0.0680, 0.5080, 0.0870, 0.5318],
        [0.2820, 0.6401, 0.2940, 0.6608],
        [0.4370, 0.6051, 0.4550, 0.6401],
        [0.4570, 0.7389, 0.4770, 0.7803],
        [0.4790, 0.6131, 0.4880, 0.6353],
        [0.7830, 0.8519, 0.8440, 0.9061],
        [0.9120, 0.6895, 0.9380, 0.7309],

        0.0130, 0.0745, 0.0275])}
tensor([[[0.0627, 0.0314, 0.0196,  ..., 0.3451, 0.3294, 0.3216],
         [0.0549, 0.0275, 0.0196,  ..., 0.3373, 0.3216, 0.3137],
         [0.0392, 0.0275, 0.0275,  ..., 0.3333, 0.3216, 0.3098],
         ...,
         [0.2588, 0.4078, 0.5490,  ..., 0.2235, 0.2314, 0.2314],
         [0.2549, 0.4039, 0.5608,  ..., 0.2353, 0.2275, 0.2196],
         [0.2392, 0.3725, 0.5490,  ..., 0.2902, 0.2078, 0.2000]],

        [[0.0471, 0.0235, 0.0196,  ..., 0.2980, 0.2824, 0.2745],
         [0.0392, 0.0235, 0.0196,  ..., 0.2902, 0.2745, 0.2667],
         [0.0314, 0.0196, 0.0196,  ..., 0.2863, 0.2745, 0.2627],
         ...,
         [0.1961, 0.3020, 0.4275,  ..., 0.1804, 0.1843, 0.1843],
         [0.1843, 0.2824, 0.4275,  ..., 0.1922, 0.1804, 0.1725],
         [0.1804, 0.2667, 0.4275,  ..., 0.2392, 0.1608, 0.1529]],

        [[0.0314, 0.0157, 0.0196,  ..., 0.2392, 0.2275, 0.2196],
         [0.0275, 0.0157, 0.0196,  ..., 0.2314, 0.2196, 0.2118],
         [0.0314, 0.0235

         [0.1216, 0.1216, 0.1216,  ..., 0.1294, 0.1255, 0.1216]]])
{'boxes': tensor([[0.0000, 0.0000, 0.9987, 0.9958],
        [0.3794, 0.2200, 0.4225, 0.3058],
        [0.3988, 0.2958, 0.4094, 0.3167],
        [0.4106, 0.2992, 0.4238, 0.3200],
        [0.4156, 0.2000, 0.4288, 0.2342],
        [0.4238, 0.2992, 0.4338, 0.3200],
        [0.4300, 0.1950, 0.4469, 0.2292],
        [0.4338, 0.3108, 0.4444, 0.3383],
        [0.4456, 0.3217, 0.4519, 0.3458],
        [0.4494, 0.1950, 0.4625, 0.2308],
        [0.4544, 0.3267, 0.4638, 0.3475],
        [0.4638, 0.3250, 0.4737, 0.3475],
        [0.4650, 0.1925, 0.4762, 0.2342],
        [0.4781, 0.3142, 0.5075, 0.3425],
        [0.4794, 0.1925, 0.4881, 0.2292],
        [0.4894, 0.1925, 0.5025, 0.2267],
        [0.5063, 0.1967, 0.5194, 0.2258],
        [0.5138, 0.3183, 0.5206, 0.3383],
        [0.5219, 0.3217, 0.5306, 0.3425],
        [0.5231, 0.1892, 0.5375, 0.2242],
        [0.5331, 0.3233, 0.5450, 0.3458],
        [0.5387, 0.1950, 0.5525, 0.2292],

RuntimeError: CUDA out of memory. Tried to allocate 158.00 MiB (GPU 1; 10.91 GiB total capacity; 9.80 GiB already allocated; 117.00 MiB free; 120.84 MiB cached)