In [None]:
import torch
import torchvision.models as models
from torchvision.datasets.vision import VisionDataset
import torchvision.transforms as transforms
import argparse
import os
import cv2
import numpy as np
from torch.utils.data import DataLoader
import collections
from xml.etree.ElementTree import Element as ET_Element
from xml.etree.ElementTree import parse as ET_parse
from typing import Any, Callable, Dict, List, Optional, Tuple
from PIL import Image
from google.colab.patches import cv2_imshow
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from torchvision.utils import make_grid
from PIL import ImageDraw, ImageFont
import torchvision.transforms.functional as F
from torchsummary import summary


: 

In [None]:
# Mount Google Drive into the Colab filesystem; the dataset, TensorBoard logs and
# weight checkpoints used further down all live under this mount point.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# Pick the compute device once; the training/eval helpers below read this global.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"Using device: {device}")

In [None]:
# NOTE: a `global` statement at module (cell) scope is a no-op in Python; it does
# not create the names. `writer` and `logs` are actually assigned in the setup
# cell near the end of the notebook and are read as globals by the helpers below.
global writer, logs

In [None]:
def unnormalize(tensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
    """
    Invert a per-channel ``(x - mean) / std`` normalization.

    Args:
        tensor: image tensor whose channel axis is third from last (e.g. CxHxW).
        mean: per-channel means; the defaults match the ``Normalize`` statistics
            used by this notebook's image transforms.
        std: per-channel standard deviations, same convention as ``mean``.

    Returns:
        A new tensor equal to ``tensor * std + mean``; the input is left untouched.
    """
    def _per_channel(values):
        # Shape (C, 1, 1) so the statistics broadcast over height and width.
        stats = torch.as_tensor(values, dtype=tensor.dtype, device=tensor.device)
        return stats.view(-1, 1, 1)

    channel_mean = _per_channel(mean)
    channel_std = _per_channel(std)
    return tensor * channel_std + channel_mean

In [None]:
def display_bbox_eval(image,boxes):
    """Draw predicted boxes on an image and show it inline in the notebook.

    Args:
        image: CxHxW RGB array. Values are multiplied by 255 and clipped, so the
            input is presumably in the [0, 1] range.
        boxes: iterable of box tensors indexed as [xmin, ymin, xmax, ymax]; each
            coordinate is read with ``.item()``, so plain Python lists will not work.
    """
    # Channels-last, BGR channel order is what OpenCV drawing/display expects.
    hwc_rgb = np.transpose(image, (1,2,0))
    hwc_bgr = cv2.cvtColor(hwc_rgb, cv2.COLOR_RGB2BGR)
    canvas = np.clip(hwc_bgr * 255, 0, 255).astype(np.uint8)

    for bbox in boxes:
        print("BBox: " + str(bbox))
        # Pull the four coordinates out of the tensor and truncate to pixel indices.
        corners = [int(bbox[i].item()) for i in range(4)]
        pt1 = (corners[0], corners[1])
        pt2 = (corners[2], corners[3])
        print("Pt1: " + str(pt1))
        print("Pt2: " + str(pt2))
        cv2.rectangle(canvas, pt1, pt2, (0,255,0),2)

    cv2.startWindowThread()
    cv2_imshow(canvas)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

In [None]:
def display_bbox_train(image,boxes):
    """Draw ground-truth boxes on a training image and show it inline.

    Args:
        image: CxHxW RGB array. Values are multiplied by 255 and clipped, so the
            input is presumably in the [0, 1] range.
        boxes: iterable of [xmin, ymin, xmax, ymax] entries; each coordinate only
            needs to be convertible with ``int()``.
    """
    # Channels-last + BGR for OpenCV, then quantize to 8-bit for drawing.
    frame = cv2.cvtColor(np.transpose(image, (1,2,0)), cv2.COLOR_RGB2BGR)
    image_255 = np.clip(frame * 255, 0, 255).astype(np.uint8)
    green, thickness = (0,255,0), 2

    for bbox in boxes:
        print("BBox: " + str(bbox))
        top_left = (int(bbox[0]), int(bbox[1]))
        bottom_right = (int(bbox[2]), int(bbox[3]))
        print("Pt1: " + str(top_left))
        print("Pt2: " + str(bottom_right))
        cv2.rectangle(image_255, top_left, bottom_right, green, thickness)

    cv2.startWindowThread()
    cv2_imshow(image_255)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

In [None]:
def display_bbox(dataset, mean=None, std=None):
    """Show every sample of ``dataset`` with its ground-truth boxes drawn on it.

    Args:
        dataset: iterable of ``(image, target)`` pairs. ``image`` is a CxHxW RGB
            tensor/array (CPU) and ``target["boxes"]`` holds
            ``[xmin, ymin, xmax, ymax]`` rows.
        mean: optional per-channel means. When both ``mean`` and ``std`` are
            given the image is un-normalized (``x * std + mean``) before display.
        std: optional per-channel standard deviations, see ``mean``.

    The pixel values are then scaled by 255 and clipped to 8-bit, matching
    ``display_bbox_train`` / ``display_bbox_eval``; without that a [0, 1] float
    image renders as black.
    """
    for image, target in dataset:
        # Convert first: the original printed image_array before assigning it.
        image_array = np.array(image)
        print(image.shape)
        print(image_array.shape)

        if mean is not None and std is not None:
            channel_mean = np.asarray(mean, dtype=image_array.dtype).reshape(-1, 1, 1)
            channel_std = np.asarray(std, dtype=image_array.dtype).reshape(-1, 1, 1)
            image_array = image_array * channel_std + channel_mean

        # Channels-last + BGR for OpenCV, then quantize to uint8 for drawing.
        image_array = np.transpose(image_array, (1,2,0))
        image_array = cv2.cvtColor(image_array, cv2.COLOR_RGB2BGR)
        image_array = np.clip(image_array * 255, 0, 255).astype(np.uint8)

        print("Display Target: " + str(target))
        for box in target["boxes"]:
            # cv2.rectangle only accepts integer pixel coordinates.
            xmin, ymin, xmax, ymax = (int(box[i]) for i in range(4))
            cv2.rectangle(image_array, (xmin, ymin), (xmax, ymax), (0,255,0),2)

        # Colab's cv2_imshow takes just the image (no window title).
        cv2_imshow(image_array)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

In [None]:
# Function to check if Resize is in the transform
def contains_resize(transform):
    """Return the first ``transforms.Resize`` found in ``transform``, else ``None``.

    Args:
        transform: ``None``, a ``Compose``-like object exposing its steps as
            ``.transforms``, or a single transform.

    Returns:
        The first ``Resize`` instance encountered, or ``None`` when there is no
        transform at all or none of its steps is a ``Resize``.
    """
    if transform is None:
        return None
    # A Compose lists its steps in `.transforms`; anything else is a single step.
    steps = getattr(transform, "transforms", None)
    if steps is None:
        steps = [transform]
    for step in steps:
        if isinstance(step, transforms.Resize):
            return step
    return None

In [None]:

def my_collate(batch):
    """Collate ``(image, target)`` samples into two parallel lists.

    Samples are kept as Python lists instead of being stacked, so each item may
    carry a different number of boxes.

    Args:
        batch: sequence of samples; element 0 is the data, element 1 the target.

    Returns:
        ``(data, target)``: two lists in the same order as ``batch``.
    """
    data, target = [], []
    for sample in batch:
        data.append(sample[0])
        target.append(sample[1])
    return data, target

In [None]:
class TrashDataset(VisionDataset):
    """Single-class object-detection dataset stored in a Pascal-VOC directory layout.

    Expected layout under ``root``::

        VOCdevkit/VOC2012/ImageSets/Main/<image_set>.txt   # one sample name per line
        VOCdevkit/VOC2012/JPEGImages/<name>.jpeg
        VOCdevkit/VOC2012/Annotations/<name>.xml

    Each item is ``(image, target)`` where ``target`` is
    ``{"boxes": float32 tensor [N, 4], "labels": int64 tensor [N]}`` with boxes as
    ``[xmin, ymin, xmax, ymax]``. Every annotated object gets label 1.

    If the image ``transform`` contains a ``transforms.Resize``, the boxes are
    rescaled from the source image size to the resized image size.
    """

    def __init__(
        self,
        root: str,
        image_set: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ):
        super().__init__(root, transforms, transform, target_transform)
        voc_root = os.path.join(self.root, "VOCdevkit", "VOC2012")
        split_f = os.path.join(voc_root, "ImageSets", "Main", image_set.rstrip("\n") + ".txt")
        with open(split_f) as f:
            # Skip blank lines so a trailing newline does not create a bogus sample.
            file_names = [line.strip() for line in f if line.strip()]
        image_dir = os.path.join(voc_root, "JPEGImages")
        self.images = [os.path.join(image_dir, x + ".jpeg") for x in file_names]
        target_dir = os.path.join(voc_root, "Annotations")
        self.targets = [os.path.join(target_dir, x + ".xml") for x in file_names]
        # Nominal (width, height) of the source images. Kept for backward
        # compatibility; box rescaling uses each image's real size instead.
        self.imagesize = (960.0,540.0)
        assert len(self.images) == len(self.targets)

    def __len__(self) -> int:
        return len(self.images)

    @property
    def annotations(self) -> List[str]:
        """Paths of the annotation XML files, aligned with ``self.images``."""
        return self.targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load sample ``index`` as ``(image, {"boxes": ..., "labels": ...})``."""
        # Close the file handle promptly; convert() returns an independent copy.
        with Image.open(self.images[index]) as raw:
            img = raw.convert("RGB")
        orig_w, orig_h = img.size

        annotation = self.parse_voc_xml(ET_parse(self.annotations[index]).getroot())
        objs = annotation["annotation"]["object"]

        # One foreground class only: every object is label 1 (torchvision
        # detection models reserve label 0 for background).
        labels = torch.ones(len(objs), dtype=torch.int64)
        # Float boxes are what torchvision detection models document as input and
        # avoid truncation when rescaling; reshape keeps [0, 4] for empty images.
        boxes = torch.tensor(
            [[float(obj["bndbox"][key]) for key in ("xmin", "ymin", "xmax", "ymax")] for obj in objs],
            dtype=torch.float32,
        ).reshape(-1, 4)
        target = {"boxes": boxes, "labels": labels}

        # self.transforms is the joint (image, target) callable that VisionDataset
        # builds from transform/target_transform, or the user-supplied `transforms`.
        if self.transforms is not None:
            img, target = self.transforms(img, target)

        resize = self._find_resize()
        if resize is not None:
            scale_x, scale_y = self._resize_scale(resize, orig_w, orig_h)
            factors = torch.tensor(
                [scale_x, scale_y, scale_x, scale_y],
                dtype=torch.float32,
                device=target["boxes"].device,
            )
            target["boxes"] = target["boxes"] * factors
        return img, target

    def _find_resize(self):
        """Return the ``Resize`` step of the image transform, or ``None``."""
        image_transform = self.transform
        if image_transform is None:
            return None
        if isinstance(image_transform, transforms.Resize):
            return image_transform
        if hasattr(image_transform, "transforms"):
            return contains_resize(image_transform)
        return None

    @staticmethod
    def _resize_scale(resize, orig_w, orig_h) -> Tuple[float, float]:
        """Return ``(scale_x, scale_y)`` mapping source pixels to resized pixels."""
        size = resize.size
        if isinstance(size, int) or len(size) == 1:
            # A single number resizes the shorter edge and keeps the aspect ratio
            # (integer rounding of the longer edge is ignored here).
            short_edge = size if isinstance(size, int) else size[0]
            scale = short_edge / min(orig_w, orig_h)
            max_size = getattr(resize, "max_size", None)
            if max_size is not None and max(orig_w, orig_h) * scale > max_size:
                scale = max_size / max(orig_w, orig_h)
            return scale, scale
        # torchvision's Resize stores a 2-sequence as (height, width).
        out_h, out_w = size
        return out_w / orig_w, out_h / orig_h

    @staticmethod
    def parse_voc_xml(node: ET_Element) -> Dict[str, Any]:
        """Recursively convert a VOC annotation XML element into nested dicts.

        Leaf elements map their tag to the stripped text. Repeated child tags are
        collected into lists, and ``annotation/object`` is always a list (empty
        when the image has no objects).
        """
        voc_dict: Dict[str, Any] = {}
        children = list(node)
        if children:
            def_dic: Dict[str, Any] = collections.defaultdict(list)
            for dc in map(TrashDataset.parse_voc_xml, children):
                for ind, v in dc.items():
                    def_dic[ind].append(v)
            if node.tag == "annotation":
                # Wrap once more so the len(v) == 1 unwrapping below still leaves a list.
                def_dic["object"] = [def_dic["object"]]
            voc_dict = {node.tag: {ind: v[0] if len(v) == 1 else v for ind, v in def_dic.items()}}
        if node.text:
            text = node.text.strip()
            if not children:
                voc_dict[node.tag] = text
        return voc_dict

In [None]:
def eval_model(model, dataloader):
    """Run the model over ``dataloader`` and display the best box per batch.

    For each batch, the highest-scoring detection of the FIRST image is drawn on
    that image and shown; batches whose first image has no detections are
    reported and skipped.

    Args:
        model: detection model already switched to eval mode by the caller; its
            output is indexed as a list of dicts with 'boxes' and 'scores'.
        dataloader: yields ``(images, targets)`` lists (see ``my_collate``).

    Reads the module-level ``device``.
    """
    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        for images, targets in dataloader:
            images = [image.to(device) for image in images]  # Move images to device
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]  # Move targets to device

            outputs = model(images, targets)
            print("Outputs: " + str(outputs))
            scores = outputs[0]['scores']
            boxes = outputs[0]['boxes']
            if boxes.shape[0] == 0:
                print("No detections for the first image of this batch")
                continue

            # Keep only the single highest-scoring box of the first image.
            best_index = int(torch.argmax(scores))
            high_score_boxes = boxes[best_index]

            unnormalized_image = unnormalize(images[0])
            # .cpu() is required before .numpy() when running on a GPU.
            display_bbox_eval(unnormalized_image.cpu().numpy(),[high_score_boxes])
            print(high_score_boxes)


In [None]:
# Write the images, annotated with predicted bounding boxes, to TensorBoard.
def write_image(model,images,targets,threshold,global_step,title):
    """Log a grid of images annotated with the model's top-2 predictions.

    Args:
        model: detection model; temporarily switched to eval mode and restored to
            whatever mode it was in when this function was called.
        images: list of normalized CxHxW image tensors.
        targets: matching list of target dicts (forwarded to the model).
        threshold: currently unused; the first two boxes of each image are always
            drawn regardless of score. Kept for interface compatibility.
        global_step: TensorBoard step for the logged image.
        title: tag prefix; the image is logged as ``title + "-annotated_images"``.

    Writes through the module-level ``writer``.
    """
    was_training = model.training
    model.eval()
    try:
        # Pure inference: do not build an autograd graph for the logging pass.
        with torch.no_grad():
            outputs = model(images,targets)
    finally:
        model.train(was_training)

    annotated_im = []
    for im, out in zip(images, outputs):
        # Clamp so values slightly outside [0, 1] do not wrap around when the
        # tensor is converted to 8-bit.
        image_pil = F.to_pil_image(unnormalize(im).clamp(0, 1))
        draw = ImageDraw.Draw(image_pil)
        # Always draw (up to) the first two predictions.
        for sc, bbox in zip(out['scores'][:2], out['boxes'][:2]):
            pt1 = (int(bbox[0].item()), int(bbox[1].item()))
            pt2 = (int(bbox[2].item()), int(bbox[3].item()))
            draw.rectangle([pt1, pt2], outline="green")
            draw.text(pt1,'{0:.2f}'.format(sc.item()),fill=(0,255,0))
        annotated_im.append(F.to_tensor(image_pil))

    if not annotated_im:
        return
    img_grid = make_grid(annotated_im)
    writer.add_image(title + "-annotated_images", img_grid, global_step=global_step)

In [None]:
def train_model(model,train_loader,optimizer,num_epochs,params,weights_path):
    """Train ``model`` and save a checkpoint after every epoch.

    Args:
        model: detection model in train mode; called as ``model(images, targets)``
            and expected to return a dict of losses.
        train_loader: yields ``(images, targets)`` lists (see ``my_collate``).
        optimizer: optimizer over the model's trainable parameters.
        num_epochs: number of passes over ``train_loader``.
        params: dict with 'lr', 'batch_size' and 'num_epochs'; only used to build
            the TensorBoard tag and the checkpoint file names.
        weights_path: directory for checkpoints (created if missing). Files are
            named ``weights<lr>_<batch_size>_<num_epochs>_<epoch>``.

    Reads the module-level ``device`` and ``writer``.
    """
    log_interval = 1
    image_log_interval = 5
    global_step = 0
    title = "Loss/train/" + str(params['lr']) + '/' + str(params['batch_size']) + '/' + str(params['num_epochs']) + '/'
    title_with_datetime = title + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    if weights_path:
        os.makedirs(weights_path, exist_ok=True)

    for epoch in range(num_epochs):
        for batch_idx,(images, targets) in enumerate(train_loader):
            images = [image.to(device) for image in images]  # Move images to device
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]  # Move targets to device

            outputs = model(images, targets)
            losses = sum(loss for loss in outputs.values())

            # Update first so the training graph is released before the extra
            # forward pass used for image logging (lower peak memory).
            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            loss_value = losses.item()

            if batch_idx % log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch, batch_idx * len(images), len(train_loader.dataset),
                        100. * batch_idx / len(train_loader), loss_value))
                writer.add_scalar(title_with_datetime, loss_value, global_step)
            if batch_idx % image_log_interval == 0:
                # Images are logged less often than scalars to limit memory use.
                with torch.no_grad():
                    write_image(model,images[:16],targets[:16],0.87,global_step,title_with_datetime)
                torch.cuda.empty_cache()  # release cached blocks from the logging pass
            global_step += 1

        weights_name = "weights" + str(params['lr']) + '_' + str(params['batch_size']) + '_' + str(params['num_epochs']) + '_' + str(epoch)
        weights_file = os.path.join(weights_path, weights_name)
        torch.save(model.state_dict(), weights_file)
        print("Saved weights - Epoch: {} in {}".format(epoch, weights_file))


In [None]:
# Run configuration: data/weights sub-directories (joined onto the Drive project
# folder by the setup cell), the train/eval switch, and the hyper-parameters.
path, save = 'rgb_data/', 'weights/'
train_flag = True
num_epochs, batch_size, lr = 30, 64, 2.5e-2
# Bundled copy of the hyper-parameters, used for log tags and checkpoint names.
params = dict(lr=lr, batch_size=batch_size, num_epochs=num_epochs)

In [None]:
%load_ext tensorboard
drive_path = '/content/drive/MyDrive/trash-model/classifier'
data_path = os.path.join(drive_path, path)
logs = drive_path + '/logs/tensorboard'
writer = SummaryWriter(logs)
%tensorboard --logdir /content/drive/MyDrive/trash-model/classifier/logs/tensorboard

mobile_netv3 = models.detection.ssdlite.ssdlite320_mobilenet_v3_large(weights_backbone = "DEFAULT", num_classes = 2,progress = True)
# mobile_netv3 = models.detection.ssdlite.ssdlite320_mobilenet_v3_large(pretrained = True, num_classes = 2)

mobile_netv3.to(device)  # Move model to the selected device
#Weights_backbone = pretrained=True,

#Use transforms. Should just be tensors?
image_transforms = transforms.Compose([
    transforms.ColorJitter(brightness = (0.5,1.5), contrast = (0.5,1.5), saturation = (0.5,1.5), hue = (-0.1, 0.1)),
    transforms.Resize((320,320)),
    transforms.ToTensor(),  # Convert the image to a PyTorch tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
#TODO: SSDHEadClassificaiton()
for param in mobile_netv3.parameters():
    param.requires_grad = False
for head_param in mobile_netv3.head.parameters():
    head_param.requires_grad = True

# print("Datapath: " + str(data_path))
#TODO: save weights for each epoch
weights_string = "trash_weights.pth"
weights_title = os.path.join(os.path.join(drive_path,save),weights_string)
weights_path = os.path.join(drive_path,save)

if train_flag:
    mobile_netv3.train()
    dataset = TrashDataset(root = data_path, image_set='train', transform = image_transforms)
    train_dataloader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = 2, collate_fn = my_collate)
    optimizer = torch.optim.SGD(mobile_netv3.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
    train_model(mobile_netv3,train_dataloader, optimizer,num_epochs, params,weights_path)
    writer.flush()
else:
    dataset = TrashDataset(root = data_path, image_set='val', transform = image_transforms)
    eval_dataloader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = 2, collate_fn = my_collate)
    mobile_netv3.load_state_dict(torch.load(weights_title, map_location=device))
    mobile_netv3.eval()
    eval_model(mobile_netv3,eval_dataloader)