In [None]:
import os
import numpy as np
import pandas as pd
import torch
from PIL import Image

In [None]:
# Crater annotations, one row per annotation, indexed by HiRISE observation id.
df = (
    pd.read_csv('drive/MyDrive/MarsVisionProject/HiRiseData/combined_output.csv')
    .set_index('observation_id')
)

In [None]:
df.head(n=5)

Unnamed: 0_level_0,impact_id,lat,lon,diameter_in_m,xmin,xmax,ymin,ymax,xmin_px,xmax_px,ymin_px,ymax_px,bb_width_px,bb_height_px,date_discovered,image_path,class
observation_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
ESP_058989_2050,951,24.691,265.975,4.5,0.033203,0.423828,0.602372,0.641897,17,217,3048,3248,100,100,2014-06-21,od_data/positive/raw/ESP_058989_2050.png,1
ESP_058991_1670,952,-12.681,217.286,11.6,0.009766,0.791016,0.459881,0.554743,5,405,2327,2807,240,240,2018-04-15,od_data/positive/raw/ESP_058991_1670.png,1
ESP_059004_1780,953,-2.19,220.95,3.0,0.185547,0.419922,0.447676,0.485615,95,215,1416,1536,60,60,2018-05-03,od_data/positive/raw/ESP_059004_1780.png,1
ESP_059030_2150,954,34.899,226.788,6.3,0.248047,0.716797,0.539723,0.587154,127,367,2731,2971,120,120,2018-04-02,od_data/positive/raw/ESP_059030_2150.png,1
ESP_059043_1980,955,17.716,233.726,3.6,0.009766,0.228516,0.509618,0.551779,5,117,1934,2094,80,80,2015-01-24,od_data/positive/raw/ESP_059043_1980.png,1


In [None]:
root = 'drive/MyDrive/MarsVisionProject/HiRiseData/raw'
# After set_index('observation_id') the id is the index, not a column, so
# df['observation_id'] would raise KeyError; take the first id from df.index.
img_path = os.path.join(root, df.index[0] + '.png')
# Use a context manager so the file handle is closed; convert() produces a new image.
with Image.open(img_path) as raw_img:
    img = raw_img.convert("RGB")

In [None]:
# os.listdir order is arbitrary; sort so the listing is reproducible across runs.
imgs = sorted(os.listdir('drive/MyDrive/MarsVisionProject/HiRiseData/raw'))

In [None]:
len(df)

1493

In [None]:
df.index.values[0]

'ESP_058989_2050'

In [None]:
class CraterDetectionDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(image, target)`` pairs for torchvision detection models.

    Each item is one HiRISE observation, i.e. one ``<observation_id>.png`` under ``root``.
    All CSV rows sharing that observation id with ``class == 1`` become that image's
    boxes. An observation with no positive rows yields empty ``boxes``/``labels``,
    which is how torchvision's Faster R-CNN represents a negative (background-only) image.

    Args:
        root: directory holding the ``<observation_id>.png`` images.
        csv_filename: annotation CSV with an ``observation_id`` column, a ``class``
            column and the box columns named in ``box_columns``.
        transforms: optional callable ``(img, target) -> (img, target)``; may be None.
        box_columns: CSV columns read, in order, as ``[x1, y1, x2, y2]``. Defaults to
            the ``*_px`` columns because torchvision detection models expect absolute
            pixel coordinates, not the normalised ``xmin``/``xmax`` fractions.
    """

    def __init__(self, root, csv_filename, transforms,
                 box_columns=('xmin_px', 'ymin_px', 'xmax_px', 'ymax_px')):
        self.root = root
        self.transforms = transforms
        self.box_columns = list(box_columns)

        self.df = pd.read_csv(csv_filename)
        self.df = self.df.set_index('observation_id')
        # One item per unique observation, in first-appearance order. Several rows
        # can share an id when one image contains several annotated craters.
        self.ids = list(self.df.index.unique())
        # Image file names served by this dataset, aligned with self.ids.
        self.imgs = [obs_id + '.png' for obs_id in self.ids]

    def _build_target(self, idx):
        """Build the detection target dict for item ``idx`` (no image I/O).

        Returns a dict with ``boxes`` (float32, N x 4), ``labels`` (int64, N),
        ``image_id`` (int64, 1), ``area`` (float32, N) and ``iscrowd`` (int64, N),
        where N is the number of positive rows for this observation (possibly 0).
        """
        # A list key keeps the result a DataFrame even when the id has only one row.
        rows = self.df.loc[[self.ids[idx]]]
        positives = rows[rows['class'] == 1]
        num_objs = len(positives)

        # reshape keeps the (0, 4) shape torchvision expects for negative images.
        boxes = torch.as_tensor(
            positives[self.box_columns].to_numpy(dtype='float32')
        ).reshape(-1, 4)
        # torchvision expects int64 labels and reserves 0 for background, so every
        # annotated crater is class 1.
        labels = torch.ones((num_objs,), dtype=torch.int64)
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # suppose all instances are not crowd
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        return {
            "boxes": boxes,
            "labels": labels,
            "image_id": torch.tensor([idx]),
            "area": area,
            "iscrowd": iscrowd,
        }

    def __getitem__(self, idx):
        """Return ``(image, target)`` for the ``idx``-th observation."""
        img_path = os.path.join(self.root, self.imgs[idx])
        # Context manager closes the file handle; convert() produces a new image.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")

        target = self._build_target(idx)

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Number of distinct observations (images) listed in the CSV."""
        return len(self.ids)

In [None]:
from torchvision import transforms as T

class _DetectionTransform:
    """Detection-aware transform: ``(image, target) -> (float CHW tensor, target)``.

    ``CraterDetectionDataset`` calls its transform with both the image and the
    target, which a plain ``torchvision.transforms.Compose`` cannot accept. Any
    geometric change to the image is mirrored on ``target["boxes"]``.
    """

    def __init__(self, train, flip_prob=0.5):
        self.train = train
        self.flip_prob = flip_prob

    def __call__(self, image, target=None):
        # HWC uint8 -> CHW float32 in [0, 1]; torchvision detection models reject
        # non-floating-point images.
        array = np.array(image, dtype=np.uint8)
        if array.ndim == 2:  # single-channel input: add a channel axis
            array = array[:, :, None]
        tensor = torch.from_numpy(array).permute(2, 0, 1).contiguous().float().div(255.0)

        if self.train and torch.rand(()).item() < self.flip_prob:
            tensor = tensor.flip(-1)
            if target is not None and "boxes" in target:
                width = tensor.shape[-1]
                boxes = target["boxes"]
                flipped = boxes.clone()
                # Horizontal flip: x1' = W - x2, x2' = W - x1; y is unchanged.
                flipped[:, 0] = width - boxes[:, 2]
                flipped[:, 2] = width - boxes[:, 0]
                # Build a new dict so the caller's target is not mutated.
                target = {**target, "boxes": flipped}
        return tensor, target


def get_transform(train, flip_prob=0.5):
    """Return the ``(image, target)`` transform used by ``CraterDetectionDataset``.

    Args:
        train: when True, apply a random horizontal flip (image and boxes together).
        flip_prob: probability of flipping when ``train`` is True.

    Returns:
        A callable mapping ``(PIL image or HWC uint8 array, target dict)`` to
        ``(float32 CHW tensor in [0, 1], target dict)``.
    """
    return _DetectionTransform(train, flip_prob)

In [None]:
# Annotation CSV and the directory of <observation_id>.png images it refers to.
csv_filename = 'drive/MyDrive/MarsVisionProject/HiRiseData/combined_output.csv'
root = 'drive/MyDrive/MarsVisionProject/HiRiseData/raw'

dataset = CraterDetectionDataset(
    root=root,
    csv_filename=csv_filename,
    transforms=get_transform(train=True),
)

In [None]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# Faster R-CNN (ResNet-50 FPN backbone) starting from its default pretrained weights.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")

# Width of the features feeding the existing classification layer; the new head
# must accept the same input size.
in_features = model.roi_heads.box_predictor.cls_score.in_features
# torchvision counts background as class 0, so 2 = background (negative) + crater.
num_classes = 2
# Swap in a box predictor sized for our classes.
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

In [None]:
import torch.utils.data

def collate_fn(batch):
    """Group a batch as (images, targets) tuples instead of stacking tensors.

    Detection images differ in size and targets hold variable numbers of boxes,
    so the default collate (which stacks tensors) cannot batch them.
    """
    return tuple(zip(*batch))


data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=8, shuffle=True, num_workers=4, collate_fn=collate_fn)

# For training: the model takes a list of CHW images and a list of target dicts.
images, targets = next(iter(data_loader))
images = list(images)
targets = [dict(t) for t in targets]
# Explicit train mode: re-running this cell after model.eval() below would
# otherwise return detections instead of losses.
model.train()
output = model(images, targets)   # Returns losses
# For inference
model.eval()
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
with torch.no_grad():  # no autograd graph needed for prediction
    predictions = model(x)           # Returns predictions

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc2086ae670>Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc2086ae670>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    
self._shutdown_workers()Traceback (most recent call last):

Exception ignored in:   File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fc2086ae670>        
if w.is_alive():self._shutdown_workers()Traceback (most recent call last):

  File "/usr/lib/python3.8/multiprocessing/process.py", line 160, in is_alive

  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    assert self._parent_pid == os.getpid(), 'can o

TypeError: ignored