In [1]:
%matplotlib inline

%load_ext autoreload
%autoreload 2

In [2]:
import time
import os
from os import listdir
from os.path import isfile, join

import pprint
import json
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image, ImageDraw, ImageColor, ImageFont
from tqdm.auto import tqdm

import torch
import torchvision
from torchvision.utils import make_grid
from torchvision.io import read_image
import torchvision.transforms.functional as F
from torchvision.utils import draw_segmentation_masks
plt.rcParams["savefig.bbox"] = 'tight'

plt.style.use('seaborn-white')
matplotlib.rcParams['mathtext.fontset'] = 'cm'
matplotlib.rcParams['font.family'] = 'STIXGeneral'

In [3]:
def load_annotation(label_path, image_key):
    with open(join(label_path, '{:s}.json'.format(image_key)), 'r') as fid:
        anno = json.load(fid)
    return anno

In [4]:
data_dir = '/data/shared/mtsd_v2_fully_annotated/'
img_path = join(data_dir, 'train')
label_path = join(data_dir, 'annotations')

filenames = [f for f in listdir(img_path) if isfile(join(img_path, f))]
np.random.shuffle(filenames)

color = 'red'
alpha = 125
try:
    font = ImageFont.truetype('arial.ttf', 15)
except:
    print('Falling back to default font...')
    font = ImageFont.load_default()

images = []
for i, filename in enumerate(filenames):
    img = Image.open(join(img_path, filename))
#     img = np.array(img)
#     bbox = load_annotation(label_path, filenames.split('.')[0])['objects']['bbox']
#     x1, y1 = bbox['xmin'], bbox['ymin']
#     x2, y2 = bbox['xmax'], bbox['ymax']
    anno = load_annotation(label_path, filename.split('.')[0])
    
    img = img.convert('RGBA')
    img_draw = ImageDraw.Draw(img)

    rects = Image.new('RGBA', img.size)
    rects_draw = ImageDraw.Draw(rects)

    for obj in anno['objects']:
        x1 = obj['bbox']['xmin']
        y1 = obj['bbox']['ymin']
        x2 = obj['bbox']['xmax']
        y2 = obj['bbox']['ymax']

        color_tuple = ImageColor.getrgb(color)
        if len(color_tuple) == 3:
            color_tuple = color_tuple + (alpha,)
        else:
            color_tuple[-1] = alpha

        rects_draw.rectangle((x1+1, y1+1, x2-1, y2-1), fill=color_tuple)
        img_draw.line(((x1, y1), (x2, y1), (x2, y2), (x1, y2), (x1, y1)), fill='black', width=1)

        class_name = obj['label']
        img_draw.text((x1 + 5, y1 + 5), class_name, font=font)

    img = Image.alpha_composite(img, rects)
    images.append(img)
    if len(images) == 4:
        break


Falling back to default font...


In [5]:
import pandas as pd 

data = pd.read_csv('/data/shared/mtsd_v2_fully_annotated/traffic_sign_dimension.csv') 
data

Unnamed: 0,label,id,shape,width,height
0,warning--added-lane-right--g1,w4-3,diamond,36.0,36.0
1,warning--bicycles-crossing--g2,w11-1,diamond,36.0,36.0
2,warning--bicycles-crossing--g3,,diamond,36.0,36.0
3,warning--bus-stop-ahead--g3,,diamond,36.0,36.0
4,warning--children--g2,,diamond,36.0,36.0
...,...,...,...,...,...
101,regulatory--turn-left--g2,r3-5,rect,30.0,36.0
102,regulatory--turn-right--g3,r3-5,rect,30.0,36.0
103,regulatory--wrong-way--g1,r5-1a,rect,42.0,30.0
104,warning--school-zone--g2,s1-1,pentagon,36.0,36.0


In [6]:
selected_labels = ['octagon,36,36', 'diamond,36,36', 'pentagon,36,36', 'rect,36,48', 'rect,30,36']

grouped_labels = {}
with open('/data/shared/mtsd_v2_fully_annotated/traffic_sign_dimension.csv', 'r') as f:
    lines = f.readlines()
    for i, line in enumerate(lines):
        for group, selected_label in enumerate(selected_labels):
            if selected_label in line:
                if group in grouped_labels:
                    grouped_labels[group].append(line.split(',')[0])
                else:
                    grouped_labels[group] = [line.split(',')[0]]
pp = pprint.PrettyPrinter(indent=4)
pp.pprint(grouped_labels)

{   0: ['regulatory--stop--g1'],
    3: [   'regulatory--maximum-speed-limit-25--g2',
           'regulatory--maximum-speed-limit-30--g3',
           'regulatory--maximum-speed-limit-35--g2',
           'regulatory--maximum-speed-limit-40--g3',
           'regulatory--maximum-speed-limit-45--g3',
           'regulatory--maximum-speed-limit-55--g2',
           'regulatory--maximum-speed-limit-65--g2',
           'regulatory--no-turn-on-red--g1'],
    4: [   'regulatory--go-straight--g3',
           'regulatory--go-straight-or-turn-left--g2',
           'regulatory--left-turn-yield-on-green--g1',
           'regulatory--one-way-left--g2',
           'regulatory--one-way-right--g2',
           'regulatory--turn-left--g2',
           'regulatory--turn-right--g3']}


In [7]:
mtsd_label_to_shape_index = {}
for key in grouped_labels:
    for sign in grouped_labels[key]:
        mtsd_label_to_shape_index[sign] = key
mtsd_label_to_shape_index

 'regulatory--go-straight--g3': 4,
 'regulatory--go-straight-or-turn-left--g2': 4,
 'regulatory--left-turn-yield-on-green--g1': 4,
 'regulatory--one-way-left--g2': 4,
 'regulatory--one-way-right--g2': 4,
 'regulatory--turn-left--g2': 4,
 'regulatory--turn-right--g3': 4,
 'regulatory--maximum-speed-limit-25--g2': 3,
 'regulatory--maximum-speed-limit-30--g3': 3,
 'regulatory--maximum-speed-limit-35--g2': 3,
 'regulatory--maximum-speed-limit-40--g3': 3,
 'regulatory--maximum-speed-limit-45--g3': 3,
 'regulatory--maximum-speed-limit-55--g2': 3,
 'regulatory--maximum-speed-limit-65--g2': 3,
 'regulatory--no-turn-on-red--g1': 3,
 'regulatory--stop--g1': 0}

In [8]:
# no need to run again since dataset has already been created and saved in /data/shared/...

for split in ['train', 'val', 'test']:
    data_dir = '/data/shared/mtsd_v2_fully_annotated/'
    img_path = join(data_dir, split)
    label_path = join(data_dir, 'annotations')

    image_keys = []
    for entry in os.scandir(img_path):
        if (entry.path.endswith(".jpg")
                or entry.path.endswith(".png")) and entry.is_file():
            image_keys.append(entry.name.split('.jpg')[0])

    for image_key in tqdm(image_keys):
        anno = load_annotation(label_path, image_key)

        with Image.open(os.path.join(img_path, '{:s}.jpg'.format(image_key))) as img:
            img = img.convert('RGBA')
        for obj in anno['objects']:
            class_name = obj['label']

            if class_name not in mtsd_label_to_shape_index:
                continue

            x1 = obj['bbox']['xmin']
            y1 = obj['bbox']['ymin']
            x2 = obj['bbox']['xmax']
            y2 = obj['bbox']['ymax']

            width_change = 0.1 * (x2 - x1)
            height_change = 0.1 * (y2 - y1)

            x1 = x1 - width_change/2
            x2 = x2 + width_change/2

            y1 = y1 - height_change/2
            y2 = y2 + height_change/2

            img_cropped = img.crop((x1, y1, x2, y2))

            img_cropped_resized = img_cropped.resize((64, 64))

            save_dir = '/data/shared/mtsd_v2_fully_annotated/'
            save_dir = join(save_dir, '{}_cropped_signs'.format(split))

            shape_index = mtsd_label_to_shape_index[class_name]

            save_dir = join(save_dir, str(shape_index))
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)

            save_dir = join(save_dir, image_key + '.jpg')

            img_cropped_resized = img_cropped_resized.convert('RGB')
            img_cropped_resized = img_cropped_resized.save(save_dir)


  0%|          | 0/36589 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [10]:
BATCH_SIZE = 256

In [11]:
# getting mean and std of each channel
train_data = torchvision.datasets.ImageFolder(root='/data/shared/mtsd_v2_fully_annotated/train_cropped_signs/', transform=torchvision.transforms.ToTensor())
train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, num_workers=0, shuffle=False)

mean = 0.
std = 0.
for images, _ in train_data_loader:
    batch_samples = images.size(0)
    images = images.view(batch_samples, images.size(1), -1)
    mean += images.mean(2).sum(0)
    std += images.std(2).sum(0)

mean /= len(train_data_loader.dataset)
std /= len(train_data_loader.dataset)

print('mean', mean)
print('std', std)

mean tensor([0.4160, 0.3795, 0.2776])
std tensor([0.1707, 0.1565, 0.1463])


### Dataloaders

In [12]:
TRANSFORM_IMG = torchvision.transforms.Compose([
    torchvision.transforms.RandomAffine(degrees=20, scale=(.9, 1.1), shear=0),
    torchvision.transforms.RandomResizedCrop(64, scale=(0.08, 1.0), ratio=(0.75, 1.3333333333333333)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=mean,
                                     std=std)
])

TRANSFORM_IMG_VALIDATION = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=mean,
                                     std=std)
])

In [13]:
train_data = torchvision.datasets.ImageFolder(root='/data/shared/mtsd_v2_fully_annotated/train_cropped_signs/', transform=TRANSFORM_IMG)
train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

In [14]:
val_data = torchvision.datasets.ImageFolder(root='/data/shared/mtsd_v2_fully_annotated/val_cropped_signs/', transform=TRANSFORM_IMG_VALIDATION)
val_data_loader = torch.utils.data.DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

In [15]:
import torch.nn as nn
import torchvision.models as models

In [16]:
model = models.resnet18(pretrained=False)
model.fc = nn.Linear(512, 5)

In [17]:
if torch.cuda.is_available():
    model = model.cuda()

In [18]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01)

In [22]:
epochs = 100
min_valid_loss = np.inf
 
for e in range(epochs):
    train_loss = 0.0
    correct_train = 0
    for data, labels in train_data_loader:
        if torch.cuda.is_available():
            data, labels = data.cuda(), labels.cuda()
         
        optimizer.zero_grad()
        logits = model(data)
        
        preds = torch.argmax(logits, dim=1)
        correct_train += (preds == labels).float().sum()

        loss = criterion(logits, labels)
        loss.backward()
        
        optimizer.step()
        train_loss += loss.item()
     
    valid_loss = 0.0
    correct_val = 0
    model.eval()
    for data, labels in val_data_loader:
        if torch.cuda.is_available():
            data, labels = data.cuda(), labels.cuda()
         
        logits = model(data)
        preds = torch.argmax(logits, dim=1)
        correct_val += (preds == labels).float().sum()
        
        loss = criterion(logits, labels)
        valid_loss += loss.item()
     
    print('Epoch: {}'.format(e))
    
    train_accuracy = 100 * correct_train / len(train_data)
    val_accuracy = 100 * correct_val / len(val_data)

    print('Training Accuracy: {}'.format(train_accuracy))
    print('Validation Accuracy: {}'.format(val_accuracy))
    print()
    if min_valid_loss > valid_loss:
        min_valid_loss = valid_loss
         
        # Saving State Dict
        torch.save(model.state_dict(), '/home/nab_126/adv-patch-bench/model_weights/resnet18.pth')

Epoch: 0
Training Accuracy: 95.7848129272461
Validation Accuracy: 97.87234497070312

Epoch: 1
Training Accuracy: 95.55193328857422
Validation Accuracy: 97.70867919921875

Epoch: 2
Training Accuracy: 95.68001556396484
Validation Accuracy: 97.46318054199219

Epoch: 3
Training Accuracy: 95.68001556396484
Validation Accuracy: 97.70867919921875

Epoch: 4
Training Accuracy: 95.69165802001953
Validation Accuracy: 97.54501342773438

Epoch: 5
Training Accuracy: 95.76152801513672
Validation Accuracy: 97.95417785644531

Epoch: 6
Training Accuracy: 95.81974792480469
Validation Accuracy: 98.11784362792969

Epoch: 7
Training Accuracy: 95.92454528808594
Validation Accuracy: 95.2536849975586

Epoch: 8
Training Accuracy: 96.16907501220703
Validation Accuracy: 98.19967651367188

Epoch: 9
Training Accuracy: 95.73823547363281
Validation Accuracy: 97.79051208496094

Epoch: 10
Training Accuracy: 95.69165802001953
Validation Accuracy: 97.46318054199219

Epoch: 11
Training Accuracy: 95.9944076538086
Validatio

Epoch: 95
Training Accuracy: 96.87936401367188
Validation Accuracy: 98.36334228515625

Epoch: 96
Training Accuracy: 96.69306182861328
Validation Accuracy: 98.11784362792969

Epoch: 97
Training Accuracy: 96.86772155761719
Validation Accuracy: 98.36334228515625

Epoch: 98
Training Accuracy: 96.78620910644531
Validation Accuracy: 98.44517517089844

Epoch: 99
Training Accuracy: 96.646484375
Validation Accuracy: 98.36334228515625

