In [1]:
import os
import pandas as pd
import torch
import torchvision
import matplotlib.pyplot as plt
%matplotlib auto

Using matplotlib backend: Qt5Agg


In [25]:
def read_data_bananas(is_train=True, data_dir=r'F:\study\ml\banana-detection'):
    """Read the banana-detection images and their bounding-box targets.

    Args:
        is_train: read the ``bananas_train`` split when True, otherwise
            ``bananas_val``.
        data_dir: root folder that contains the split folders. Defaults to the
            original author's local path; pass your own location instead of
            editing the function.

    Returns:
        (images, targets):
            images: list with one ``torchvision.io.read_image`` result per row
                of ``<split>/label.csv``, in CSV order.
            targets: float tensor of shape (num_rows, 1, num_label_columns)
                holding the CSV columns (label, xmin, ymin, xmax, ymax). The
                coordinate columns are divided by 256; the class label is kept
                as-is.
    """
    split_dir = os.path.join(data_dir, 'bananas_train' if is_train else 'bananas_val')
    csv_data = pd.read_csv(os.path.join(split_dir, 'label.csv')).set_index('img_name')
    images = [torchvision.io.read_image(os.path.join(split_dir, 'images', str(img_name)))
              for img_name in csv_data.index]
    targets = torch.tensor(csv_data.to_numpy(), dtype=torch.float32)
    # Normalize only the box coordinates (presumably by the 256-pixel image edge).
    # Dividing the class column as well would turn any non-zero class id into a fraction.
    targets[:, 1:] /= 256
    # Add a "number of objects" axis: each image has exactly one box here.
    return images, targets.unsqueeze(1)

In [20]:
# Scratch exploration: load the training split's label CSV by hand, indexed by image file name.
is_train = True
data_dir = r'F:\study\ml\banana-detection'
split_dir = os.path.join(data_dir, 'bananas_train' if is_train else 'bananas_val')
csv_fname = os.path.join(split_dir, 'label.csv')
csv_data = pd.read_csv(csv_fname).set_index('img_name')

In [23]:
csv_data.head(n=5)

Unnamed: 0_level_0,label,xmin,ymin,xmax,ymax
img_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0.png,0,104,20,143,58
1.png,0,68,175,118,223
2.png,0,163,173,218,239
3.png,0,48,157,84,201
4.png,0,32,34,90,86


In [22]:
# Show how iterrows() yields (index label, row Series) pairs for the first rows.
for i, j in csv_data.head().iterrows():
    print(i, '--------------', j, sep='\n')
    
    

0.png
--------------
label      0
xmin     104
ymin      20
xmax     143
ymax      58
Name: 0.png, dtype: int64
1.png
--------------
label      0
xmin      68
ymin     175
xmax     118
ymax     223
Name: 1.png, dtype: int64
2.png
--------------
label      0
xmin     163
ymin     173
xmax     218
ymax     239
Name: 2.png, dtype: int64
3.png
--------------
label      0
xmin      48
ymin     157
xmax      84
ymax     201
Name: 3.png, dtype: int64
4.png
--------------
label     0
xmin     32
ymin     34
xmax     90
ymax     86
Name: 4.png, dtype: int64


In [26]:
fs, ls = read_data_bananas(is_train=True)  # features (images) and labels of the training split

In [28]:
fs[0].shape  # shape of the first image

torch.Size([3, 256, 256])

In [29]:
ls[0:5]

tensor([[[0.0000, 0.4062, 0.0781, 0.5586, 0.2266]],

        [[0.0000, 0.2656, 0.6836, 0.4609, 0.8711]],

        [[0.0000, 0.6367, 0.6758, 0.8516, 0.9336]],

        [[0.0000, 0.1875, 0.6133, 0.3281, 0.7852]],

        [[0.0000, 0.1250, 0.1328, 0.3516, 0.3359]]])

In [30]:
class BananasDataset(torch.utils.data.Dataset):
    """Map-style dataset over one split of the banana-detection data.

    Item ``idx`` is ``(image, label)``: the stored image tensor converted to
    float, and the matching entry of the targets returned by
    ``read_data_bananas``.
    """

    def __init__(self, is_train):
        # Every image of the split is read into memory up front.
        self.features, self.labels = read_data_bananas(is_train)
        split = 'training' if is_train else 'validation'
        print(f'read {len(self.features)} {split} examples')

    def __getitem__(self, idx):
        # Convert to float for model input (images are presumably stored as uint8 — confirm).
        return (self.features[idx].float(), self.labels[idx])

    def __len__(self):
        return len(self.features)

In [31]:
def load_data_bananas(batch_size):
    """Build the banana-detection data loaders.

    Returns ``(train_iter, val_iter)``. The training loader reshuffles every
    epoch; the validation loader keeps file order. Both use ``batch_size``.
    """
    make_loader = torch.utils.data.DataLoader
    train_loader = make_loader(BananasDataset(is_train=True), batch_size, shuffle=True)
    val_loader = make_loader(BananasDataset(is_train=False), batch_size)
    return train_loader, val_loader

In [32]:
# Build both loaders and take one shuffled training minibatch to check its shapes.
batch_size, edge_size = 32, 256
# Index instead of unpacking into `_`: assigning `_` in IPython shadows the
# output-history variable and stops IPython from updating it in later cells.
train_iter = load_data_bananas(batch_size)[0]
# NOTE(review): `batch_size` is rebound here from the int 32 to the first
# (images, labels) minibatch, and later cells index it as batch_size[0] /
# batch_size[1]. Renaming it (e.g. to `batch`) must be done together with
# those cells.
batch_size = next(iter(train_iter))
batch_size[0].shape, batch_size[1].shape

read  1000  training examples 
read  100  validation examples


(torch.Size([32, 3, 256, 256]), torch.Size([32, 1, 5]))

In [59]:
def bbox_to_rect(bbox, color):
    """Turn a corner-format box (xmin, ymin, xmax, ymax) into an unfilled
    matplotlib Rectangle outlined in ``color`` with line width 2."""
    x_min, y_min, x_max, y_max = bbox[0], bbox[1], bbox[2], bbox[3]
    return plt.Rectangle(
        xy=(x_min, y_min),
        width=x_max - x_min,
        height=y_max - y_min,
        fill=False,
        edgecolor=color,
        linewidth=2,
    )

In [67]:
def show_bboxes(axes, bboxes, labels=None, colors=None):
    """Draw bounding boxes, with optional text labels, on a matplotlib Axes.

    Args:
        axes: the Axes to draw on.
        bboxes: iterable of boxes (xmin, ymin, xmax, ymax) in plot coordinates.
            Each box may be a torch tensor (on any device, with or without
            grad) or any indexable sequence / numpy array.
        labels: a single label or a list/tuple of labels; box ``i`` is
            annotated with ``labels[i]`` when that entry exists.
        colors: a single color or a list/tuple of colors, cycled over the
            boxes. Defaults to ['b', 'g', 'r', 'm', 'c'].
    """
    def _make_list(obj, default_values=None):
        # Normalize None / scalar / sequence arguments to a list-like value.
        if obj is None:
            obj = default_values
        elif not isinstance(obj, (list, tuple)):
            obj = [obj]
        return obj

    labels = _make_list(labels)
    colors = _make_list(colors, ['b', 'g', 'r', 'm', 'c'])
    for i, bbox in enumerate(bboxes):
        color = colors[i % len(colors)]
        if torch.is_tensor(bbox):
            # .numpy() only works on CPU tensors that do not require grad.
            bbox = bbox.detach().cpu().numpy()
        rect = bbox_to_rect(bbox, color)
        axes.add_patch(rect)
        if labels and len(labels) > i:
            # Use black text on white boxes so the label stays readable.
            text_color = 'k' if color == 'w' else 'w'
            axes.text(rect.xy[0], rect.xy[1], labels[i], va='center', ha='center',
                      fontsize=9, color=text_color, bbox=dict(facecolor=color, lw=0))
    

In [68]:
def show_images(images, num_rows, num_cols, titles=None, scale=1.5):
    """Plot images on a num_rows x num_cols grid without axis ticks.

    Args:
        images: iterable of images accepted by ``Axes.imshow``; torch tensors
            are converted to numpy first (on any device, with or without grad).
        num_rows, num_cols: grid shape. Extra images beyond the grid are ignored.
        titles: optional per-image titles; entries past the end are skipped.
        scale: figure size per grid cell.

    Returns:
        1-D numpy array of all Axes in row-major order.
    """
    figsize = (num_cols * scale, num_rows * scale)
    # squeeze=False always yields a 2-D array, so flatten() also works for a 1x1 grid.
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, images)):
        if torch.is_tensor(img):
            img = img.detach().cpu().numpy()
        ax.imshow(img)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if titles and i < len(titles):
            # Axes.title is a Text attribute, not a method; set_title is the setter.
            ax.set_title(titles[i])
    return axes

In [69]:
# First ten images of the minibatch: (N, C, H, W) -> (N, H, W, C) for imshow,
# divided by 255 (assuming 0-255 pixel values) to land in [0, 1].
imgs = batch_size[0][:10].permute(0, 2, 3, 1) / 255
axes = show_images(imgs, 2, 5, scale=2)
for ax, label in zip(axes, batch_size[1][:10]):
    # Each label row is [label, xmin, ymin, xmax, ymax] with coordinates
    # normalized by 256; scale them back to pixels before drawing.
    box = label[0, 1:5] * edge_size
    show_bboxes(ax, [box], colors=['w'])
    