In [1]:
from collections import defaultdict
from tqdm.notebook import tqdm
import cv2
import collections
import os

import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt

import utils
from utils_fucntions import*

import torch.utils
import torch
import torch.utils.data
from torchvision import transforms
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

from PIL import Image, ImageFile
# Let PIL decode JPEG files whose data is truncated instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Side length (pixels) that every image and mask is resized to.
IMAGE_SIZE = 256

# Seed the RNGs so the initialisation of the new heads and the DataLoader
# shuffling order are repeatable across runs.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [2]:
# 46 annotated categories plus one extra class (index 0) for background.
NUM_CATEGORIES = 46
num_classes = NUM_CATEGORIES + 1

## Датасет

In [3]:
class FashionDataset(torch.utils.data.Dataset):
    """Instance-segmentation dataset built from an RLE annotation table.

    Each item is ``(image, target, file_name)`` where ``target`` is a dict with
    the keys torchvision detection models expect: ``boxes``, ``labels``,
    ``masks``, ``image_id``, ``area`` and ``iscrowd``.

    Args:
        image_dir: directory holding the images; files are read as
            ``os.path.join(image_dir, ImageId) + '.jpg'``.
        df: annotation table with one row per object and the columns
            ``ImageId``, ``EncodedPixels``, ``ClassId``, ``Height``, ``Width``.
            The caller's frame is not modified.
        height, width: size every image and mask is resized to.
        transforms: optional callable applied to the resized PIL image.
    """

    # Objects whose bounding box (in resized pixels) is smaller than this on
    # either axis are dropped.
    MIN_BOX_SIZE = 20

    def __init__(self, image_dir, df, height, width, transforms=None):
        self.transforms = transforms
        self.image_dir = image_dir
        self.height = height
        self.width = width
        self.image_info = collections.defaultdict(dict)
        # assign() returns a new frame, so the caller's DataFrame does not
        # silently gain a CategoryId column.
        self.df = df.assign(CategoryId=df.ClassId.apply(lambda x: str(x).split("_")[0]))

        # One row per image: lists of RLE strings / category ids plus its size.
        per_image = self.df.groupby('ImageId')[['EncodedPixels', 'CategoryId']].agg(lambda x: list(x)).reset_index()
        sizes = self.df.groupby('ImageId')[['Height', 'Width']].mean().reset_index()
        per_image = per_image.merge(sizes, on='ImageId', how='left')

        for index, row in tqdm(per_image.iterrows(), total=len(per_image)):
            image_id = row['ImageId']
            self.image_info[index] = {
                "image_id": image_id,
                "image_path": os.path.join(self.image_dir, image_id),
                "width": self.width,
                "height": self.height,
                "labels": row["CategoryId"],
                # groupby().mean() yields floats; pass integer sizes to the decoder.
                "orig_height": int(row["Height"]),
                "orig_width": int(row["Width"]),
                "annotations": row["EncodedPixels"],
            }

    def __getitem__(self, idx):
        info = self.image_info[idx]
        img_path = info["image_path"]
        # Close the file handle as soon as the pixels are copied out.
        with Image.open(img_path + '.jpg') as raw:
            img = raw.convert("RGB")
        img = img.resize((self.width, self.height), resample=Image.BILINEAR)

        masks, labels = self._load_masks(info)
        boxes, kept_labels, kept_masks = self._filter_objects(masks, labels)
        target = self._build_target(idx, boxes, kept_labels, kept_masks)

        if self.transforms is not None:
            img = self.transforms(img)

        return img, target, os.path.basename(img_path) + '.jpg'

    def __len__(self):
        return len(self.image_info)

    def _load_masks(self, info):
        """Decode and resize every RLE annotation of one image.

        Returns ``(masks, labels)``: a uint8 array of shape
        ``(num_annotations, height, width)`` and the category ids shifted by +1
        so that label 0 remains free for background.
        """
        # Arrays are indexed (rows, cols) == (height, width); PIL sizes are (width, height).
        masks = np.zeros((len(info['annotations']), self.height, self.width), dtype=np.uint8)
        labels = []
        for m, (annotation, label) in enumerate(zip(info['annotations'], info['labels'])):
            sub_mask = rle_decode(annotation, (info['orig_height'], info['orig_width']))
            sub_mask = Image.fromarray(sub_mask).resize((self.width, self.height), resample=Image.BILINEAR)
            masks[m] = np.asarray(sub_mask)
            labels.append(int(label) + 1)
        return masks, labels

    def _filter_objects(self, masks, labels):
        """Keep objects whose mask bounding box is at least MIN_BOX_SIZE on both axes.

        Returns ``(boxes, labels, masks)`` lists with boxes as
        ``[xmin, ymin, xmax, ymax]``. If no object survives, a single
        placeholder (box ``[0, 0, 20, 20]``, label 0, the first mask) is
        returned so the target is never empty.
        """
        boxes, kept_labels, kept_masks = [], [], []
        for mask, label in zip(masks, labels):
            ys, xs = np.nonzero(mask)
            if xs.size == 0:  # the object vanished when resized
                continue
            xmin, xmax = xs.min(), xs.max()
            ymin, ymax = ys.min(), ys.max()
            if xmax - xmin >= self.MIN_BOX_SIZE and ymax - ymin >= self.MIN_BOX_SIZE:
                boxes.append([xmin, ymin, xmax, ymax])
                kept_labels.append(label)
                kept_masks.append(mask)

        if not kept_labels:
            boxes.append([0, 0, 20, 20])
            kept_labels.append(0)
            kept_masks.append(masks[0])
        return boxes, kept_labels, kept_masks

    @staticmethod
    def _build_target(idx, boxes, labels, masks):
        """Pack the kept objects into the torchvision detection target dict."""
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        return {
            "boxes": boxes,
            "labels": torch.as_tensor(labels, dtype=torch.int64),
            "masks": torch.as_tensor(np.stack(masks), dtype=torch.uint8),
            "image_id": torch.tensor([idx]),
            "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
            # One entry per *kept* object so it lines up with boxes/labels/masks.
            "iscrowd": torch.zeros((len(labels),), dtype=torch.int64),
        }

In [4]:
# Turn each resized PIL image into a float CHW tensor scaled to [0, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
])

In [5]:
# Root of the competition data; override with the FASHION_DATA_DIR environment
# variable instead of editing a hard-coded absolute path.
DATA_DIR = os.environ.get('FASHION_DATA_DIR', 'D:/project')
BATCH_SIZE = 2


def collate_fn(batch):
    """Regroup a list of ``(image, target, name)`` samples into per-field tuples.

    Detection targets differ in size per image, so the default stacking collate
    cannot be used. Unlike a lambda, a named function can be pickled, which
    multi-process loading (``num_workers > 0`` with the spawn start method)
    needs. NOTE(review): on Windows notebooks the workers may additionally
    require this function to live in an importable module.
    """
    return tuple(zip(*batch))


train_df = pd.read_csv(f'{DATA_DIR}/train.csv')
dataset_train = FashionDataset(f'{DATA_DIR}/train/',
                               train_df,
                               IMAGE_SIZE,
                               IMAGE_SIZE,
                               transforms=transform)

train_data_loader = torch.utils.data.DataLoader(
    dataset_train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

  0%|          | 0/45623 [00:00<?, ?it/s]

## Модель

In [6]:
def get_MaskRCNN_Model(num_classes, device, hidden_layer=512):
    """Build a pretrained Mask R-CNN (ResNet-50 FPN) with heads sized for ``num_classes``.

    Args:
        num_classes: number of output classes, background included.
        device: device the model is moved to.
        hidden_layer: channel width of the new mask predictor's hidden layer.

    Returns:
        The model on ``device`` with every parameter set to require gradients.
    """
    model_ft = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

    # Replace the box classifier/regressor with one sized for our classes.
    in_features = model_ft.roi_heads.box_predictor.cls_score.in_features
    model_ft.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Replace the mask predictor likewise, keeping its input channel count.
    in_features_mask = model_ft.roi_heads.mask_predictor.conv5_mask.in_channels
    model_ft.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)

    model_ft.to(device)
    # Fine-tune the whole network, including any backbone layers the
    # constructor may have left frozen.
    for param in model_ft.parameters():
        param.requires_grad = True

    return model_ft

model_ft = get_MaskRCNN_Model(num_classes=num_classes, device=device)
# Put the model in training mode (dropout/batch-norm behaviour, loss outputs).
model_ft.train();

## Тренировка

In [7]:
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=100, scaler=None):
    """Train ``model`` for one pass over ``data_loader`` using mixed precision.

    Args:
        model: detection model that returns a dict of losses when called as
            ``model(images, targets)`` in training mode.
        optimizer: optimizer over the model's parameters.
        data_loader: yields ``(images, targets, file_names)`` batches.
        device: device images and targets are moved to.
        epoch: epoch index; on epoch 0 an LR warm-up scheduler is used.
        print_freq: logging interval (iterations) for the metric logger.
        scaler: optional ``torch.cuda.amp.GradScaler`` to reuse across epochs
            so its loss scale is not reset; a fresh one is created when None.

    Returns:
        list of float: the reduced total loss of every iteration.
    """
    history = []
    if scaler is None:
        scaler = torch.cuda.amp.GradScaler()
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)

    lr_scheduler = None
    if epoch == 0:
        # Presumably a warm-up starting at warmup_factor * lr over warmup_iters steps.
        warmup_factor = 1. / 1000
        warmup_iters = min(1000, len(data_loader) - 1)
        lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)

    for images, targets, _ in metric_logger.log_every(data_loader, print_freq, header):
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Only the forward pass runs under autocast; backward and the optimizer
        # step belong outside the autocast region.
        with torch.cuda.amp.autocast():
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

        # NOTE(review): assumes utils.reduce_dict is the torchvision-reference
        # helper that averages losses across processes for logging; gradients
        # are taken from the local, un-reduced ``losses`` instead.
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        history.append(losses_reduced.item())

        optimizer.zero_grad()
        scaler.scale(losses).backward()
        scaler.step(optimizer)
        scaler.update()

        if lr_scheduler is not None:
            lr_scheduler.step()

        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    return history

In [8]:
LEARNING_RATE = 1e-3
# Adam over every parameter of the model.
optimizer = torch.optim.Adam(params=model_ft.parameters(), lr=LEARNING_RATE)

In [9]:
num_epochs = 1
# Accumulate per-iteration losses over all epochs instead of keeping only the
# last epoch's list.
history = []
for epoch in tqdm(range(num_epochs)):
    history.extend(train_one_epoch(model_ft, optimizer, train_data_loader, device, epoch))

  0%|          | 0/1 [00:00<?, ?it/s]



Epoch: [0]  [    0/22812]  eta: 21:21:23  lr: 0.000002  loss: 5.6876 (5.6876)  loss_classifier: 3.8810 (3.8810)  loss_box_reg: 0.4412 (0.4412)  loss_mask: 1.2071 (1.2071)  loss_objectness: 0.1307 (0.1307)  loss_rpn_box_reg: 0.0276 (0.0276)  time: 3.3703  data: 0.5984  max mem: 2551
Epoch: [0]  [  100/22812]  eta: 6:36:57  lr: 0.000102  loss: 1.4166 (2.2123)  loss_classifier: 0.5355 (1.0611)  loss_box_reg: 0.3736 (0.3885)  loss_mask: 0.4871 (0.6791)  loss_objectness: 0.0183 (0.0651)  loss_rpn_box_reg: 0.0087 (0.0185)  time: 1.1116  data: 0.7231  max mem: 3396


KeyboardInterrupt: 

In [16]:
checkpoint_path = f'adam_warmup_pytorch_mrcnn_{IMAGE_SIZE}'
# Persist only the weights (state_dict), not the whole model object.
torch.save(model_ft.state_dict(), checkpoint_path)

In [None]:
# Equivalent to .eval(): torchvision detection models then return predictions instead of losses.
model_ft.train(False);