In [1]:
import data.dataset as dataset

from models.model import HybridNet
from utils.anchors import Anchor
from utils.labels import get_detection_labels
import torch
import torchvision
import numpy as np
import tqdm
import cv2


In [2]:
detection_class_num = 8

detection_datset = dataset.DetectionDataset()
segment_dataset = dataset.SegmentDataset()
anchor_generator = Anchor((375,1242))


In [3]:
img_transform = torchvision.transforms.Compose(
    [torchvision.transforms.Resize(512, interpolation=torchvision.transforms.InterpolationMode.NEAREST)]
)

def detection_rescale(X):
    img_batch = []
    label_batch = []
    for img, label in X:
        c, h, w = img.size()
        scale_h = h / 375
        scale_w = h / 1242
        scale = torch.tensor([1, scale_w, scale_h, scale_w, scale_h])
        label = label * scale
        img = torchvision.transforms.Resize((375,1242))(img)
        img_batch.append(img)
        labels = []
        for anchor in anchor_generator.anchors_list[-6:]:
            label_ = get_detection_labels(label, anchor)
            labels.append(label_)

        labels = torch.concat(labels, dim=0)
        label_batch.append(labels)
    return torch.stack(img_batch,dim=0), torch.stack(label_batch,dim=0)


def segmentation_rescale(X):
    img_batch = []
    label_batch = []
    for img, label in X:
        img = img_transform(img)
        label = img_transform(label)
        img_batch.append(img)
        label_batch.append(label)
    return torch.stack(img_batch,dim=0), torch.stack(label_batch,dim=0)

detection_dataloader = torch.utils.data.DataLoader(detection_datset, batch_size=8, shuffle=True, collate_fn = detection_rescale)
segment_dataloader = torch.utils.data.DataLoader(segment_dataset, batch_size=8, shuffle=True, collate_fn = segmentation_rescale)

torch.save(detection_dataloader, "detection_dataset.pt")
torch.save(segment_dataloader, "segment_dataset.pt")


In [4]:
detection_dataloader = torch.load("detection_dataset.pt")
segment_dataloader = torch.load("segment_dataset.pt")

In [5]:
net = HybridNet(128, anchor_generator.anchor_nums, detection_class_num)
net = net.to('cuda')

In [6]:
# # run training
# import tqdm
# import gc
# loss_fn = torch.nn.BCELoss()
# optimizer = torch.optim.Adam(net.parameters())

# torch.cuda.empty_cache()
# gc.collect()

# for i in range(5):
#     pbar = tqdm.tqdm(total = len(segment_dataloader))
#     accum = [0] * 2
#     for batch in segment_dataloader:
#         X, y = batch
#         y = torch.where(y==7, 1, 0)
#         y = y.float()
#         X = X.to('cuda')
#         X = X / 255.0
#         pred = net(X)
#         pred = pred[0].to('cpu')
#         loss = loss_fn(pred, y)
#         pred = torch.where(pred > 0.5, 1, 0)
#         acc = (pred == y).to(torch.uint8)
#         accum[0] = accum[0] + torch.sum(acc)
#         accum[1] = accum[1] + (acc.size()[0]*acc.size()[1]*acc.size()[2]*acc.size()[3])
#         acc = accum[0] / accum[1]
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         pbar.set_description(f"loss:{loss:.3f}, acc:{acc*100:.3f}%")
#         pbar.update()
#     pbar.close()

In [None]:
from loss.detection_loss import det_loss
loss = det_loss(detection_class_num)
optimizer = torch.optim.Adam(net.parameters())

for i in range(10):
    pbar = tqdm.tqdm(total = len(detection_dataloader))
    accum = [0] * 2
    for batch in detection_dataloader:
        X, y = batch
        X = X.to('cuda')
        X = X / 255.0
        pred = net(X)
        score = loss(pred[1], y)
        optimizer.zero_grad()
        score.backward()
        optimizer.step()
        pbar.set_description(f"loss:{score:.3f}")
        pbar.update()
    pbar.close()

loss:1.840: 100%|██████████| 936/936 [18:52<00:00,  1.21s/it]
loss:1.847: 100%|██████████| 936/936 [15:24<00:00,  1.01it/s]
loss:1.665: 100%|██████████| 936/936 [15:24<00:00,  1.01it/s]
loss:1.379: 100%|██████████| 936/936 [15:22<00:00,  1.01it/s]
loss:1.469: 100%|██████████| 936/936 [15:21<00:00,  1.02it/s]
loss:1.108: 100%|██████████| 936/936 [15:23<00:00,  1.01it/s]
loss:0.916: 100%|██████████| 936/936 [15:21<00:00,  1.02it/s]
loss:1.748: 100%|██████████| 936/936 [15:22<00:00,  1.01it/s]
loss:0.895: 100%|██████████| 936/936 [15:22<00:00,  1.02it/s]
loss:1.515: 100%|██████████| 936/936 [15:28<00:00,  1.01it/s]
loss:0.802: 100%|██████████| 936/936 [15:25<00:00,  1.01it/s]
loss:0.577: 100%|██████████| 936/936 [15:27<00:00,  1.01it/s]
loss:nan: 100%|██████████| 936/936 [15:37<00:00,  1.00s/it]  
loss:nan: 100%|██████████| 936/936 [15:33<00:00,  1.00it/s]
loss:nan: 100%|██████████| 936/936 [15:29<00:00,  1.01it/s]
loss:nan: 100%|██████████| 936/936 [15:34<00:00,  1.00it/s]
loss:nan: 100%

In [None]:

torch.save(detection_dataloader, 'detection_dataset.pt')
torch.save(segment_dataloader, 'segmentation_dataset.pt')

net = torch.save(net, 'trained_model.pth')
net = torch.load("trained_model.pth")

import matplotlib.pyplot as plt
sample = next(iter(segment_dataloader))
sample = sample[0]
sample = sample.to('cuda')
sample = sample.to(torch.float32)
output = net(sample)
output = output[1]
output = output.to('cpu')
output = torch.permute(output, (0,2,3,1))
output = output.detach().numpy()

fig, axes = plt.subplots(3,3)
fig.set_size_inches(16,16)

sample = sample.to('cpu').to(torch.uint8)
sample = torch.permute(sample, (0,2,3,1))
for i in range(8):
    axes[i//3, i%3].imshow(sample[i])

fig, axes = plt.subplots(3,3)
for i in range(8):
    axes[i//3, i%3].imshow(output[i])



AttributeError: 'list' object has no attribute 'to'