In [1]:
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
try:
    from torchvision import transforms, utils
except:
    !conda install --yes torchvision --no-channel-priority
    from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
try:
    from torchvision import transforms, utils
except:
    !pip install torchvision
    from torchvision import transforms, utils
from PIL import Image

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

# You want to change these to be your own filenames
csv_file = 'flood_sample_metadata.csv'
label_csv = 'flood_sample_label.csv'

In [2]:
pd.options.display.max_rows = 1000
pd.options.display.max_columns = 50

In [3]:
# read the tsv file with the labels
ladi_responses = pd.read_csv("http://ladi.s3-us-west-2.amazonaws.com/Labels/ladi_aggregated_responses_url.tsv",delimiter='\t',header='infer')
ladi_responses.head(1000)

Unnamed: 0,url,WorkerId,Answer
0,https://ladi.s3-us-west-2.amazonaws.com/Images...,0,['damage:flood/water']
1,https://ladi.s3-us-west-2.amazonaws.com/Images...,1,['damage:flood/water']
2,https://ladi.s3-us-west-2.amazonaws.com/Images...,2,['damage:flood/water']
3,https://ladi.s3-us-west-2.amazonaws.com/Images...,20,['damage:flood/water']
4,https://ladi.s3-us-west-2.amazonaws.com/Images...,0,['damage:flood/water']
5,https://ladi.s3-us-west-2.amazonaws.com/Images...,8,"['damage:flood/water', 'damage:washout']"
6,https://ladi.s3-us-west-2.amazonaws.com/Images...,9,"['damage:rubble', 'damage:flood/water']"
7,https://ladi.s3-us-west-2.amazonaws.com/Images...,0,['damage:rubble']
8,https://ladi.s3-us-west-2.amazonaws.com/Images...,1,['damage:none']
9,https://ladi.s3-us-west-2.amazonaws.com/Images...,11,['damage:flood/water']


In [6]:
# Strip off bracket and comma from the Answer catagory
ladi_responses["Answer"] = ladi_responses["Answer"].str.strip('[|]')
# split list of responses into multiple rows
ladi_responses["Answer"] = ladi_responses["Answer"].str.split(",",expand = True)
# remove the single quote character from either end of string
ladi_responses["Answer"] = ladi_responses["Answer"].str.strip('\'')
# add a column to help with aggregation when pivoting
ladi_responses["response_count"] = 1
# Create a matrix with the number of workers who answered given label for given image
# using pivot table; filling in nan values with 0
label_matrix = ladi_responses.pivot_table(values='response_count', 
                                          index='url', 
                                          columns='Answer', 
                                          aggfunc='sum',
                                          fill_value=0)
label_matrix

Answer,damage:flood/water,damage:landslide,damage:misc,damage:none,damage:rubble,damage:smoke/fire,damage:washout,environment:dirt,environment:grass,environment:lava,environment:none,environment:rock,environment:sand,environment:shrubs,environment:snow/ice,environment:trees,infrastructure:bridge,infrastructure:building,infrastructure:communications-tower,infrastructure:dam-levee,infrastructure:none,infrastructure:pipe,infrastructure:railway,infrastructure:road,infrastructure:utility-line,infrastructure:water-tower,vehicle:aircraft,vehicle:boat,vehicle:car,vehicle:none,vehicle:truck,water:flooding,water:lake,water:none,water:ocean,water:puddle,water:river
url,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1
https://ladi.s3-us-west-2.amazonaws.com/Images/FEMA_CAP/1013/20145/DSC_0020_e34a1edc-6d5c-472e-847e-89dac3ed4519.jpg,2,0,0,4,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
https://ladi.s3-us-west-2.amazonaws.com/Images/FEMA_CAP/1013/20145/DSC_0028_18dcd0d8-4b79-452e-9ade-604d4f13ddfd.jpg,3,0,0,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
https://ladi.s3-us-west-2.amazonaws.com/Images/FEMA_CAP/1013/20145/DSC_0035_add5632e-eec2-42a4-a7db-8c42871164c2.jpg,1,0,0,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
https://ladi.s3-us-west-2.amazonaws.com/Images/FEMA_CAP/1013/20145/DSC_0042_970b36fb-582b-4b51-a581-923efb394278.jpg,2,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
https://ladi.s3-us-west-2.amazonaws.com/Images/FEMA_CAP/1013/20147/DSC_1575_787dc8f2-fb8f-4464-99a2-45ba5fc677c2.jpg,2,0,0,1,0,0,0,0,0,0,0,0,0,0,0,3,0,3,0,0,0,0,0,0,0,0,0,0,3,0,0,3,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
https://ladi.s3-us-west-2.amazonaws.com/Images/FEMA_CAP/9168/616298/DSC_3990_bdd98b00-d138-4daa-b842-4a4266cb2de9.jpg,0,1,3,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
https://ladi.s3-us-west-2.amazonaws.com/Images/FEMA_CAP/9168/616298/DSC_3997_d3d2337b-b3e3-468c-8e2c-93ba34f9a219.jpg,1,0,1,5,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
https://ladi.s3-us-west-2.amazonaws.com/Images/FEMA_CAP/9168/616298/DSC_4000_af3e8e76-fbba-4a6f-8aa7-44cd7c35fdaf.jpg,0,0,1,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
https://ladi.s3-us-west-2.amazonaws.com/Images/FEMA_CAP/9168/616298/DSC_4005_1dd9e503-51f1-4857-ad95-6aec4abfe521.jpg,0,2,1,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


In [10]:
# damage_matrix[damage_matrix['damage:washout'] > 0].shape[0]

# floodwater: 14K
# landslide: 800, remove
# misc: remove
# none: 34K
# rubble: 4K
# smoke/fire: 1.2K
# washout: 674, remove


In [7]:
# if none is greater than each of the other labels, keep none
# else, drop none

# if anything at end is > 0, that is a label.
labels = ['damage:flood/water', 'damage:rubble', 'damage:smoke/fire']
def proc(row):
    
    
    arr = [row[x] for x in labels]
    if row['damage:none'] > max(arr):
        for x in labels:
            row[x] = 0
        row['damage:none'] = 1
    else:
        row['damage:none'] = 0
        
        for x in labels:
            if row[x] > 0: row[x] = 1
            
    return row
        
    

damage_matrix = label_matrix[labels + ['damage:none']]
final_mat = damage_matrix.apply(proc, axis = 1)

In [8]:
final_mat

Answer,damage:flood/water,damage:rubble,damage:smoke/fire,damage:none
url,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
https://ladi.s3-us-west-2.amazonaws.com/Images/FEMA_CAP/1013/20145/DSC_0020_e34a1edc-6d5c-472e-847e-89dac3ed4519.jpg,0,0,0,1
https://ladi.s3-us-west-2.amazonaws.com/Images/FEMA_CAP/1013/20145/DSC_0028_18dcd0d8-4b79-452e-9ade-604d4f13ddfd.jpg,1,0,0,0
https://ladi.s3-us-west-2.amazonaws.com/Images/FEMA_CAP/1013/20145/DSC_0035_add5632e-eec2-42a4-a7db-8c42871164c2.jpg,0,0,0,1
https://ladi.s3-us-west-2.amazonaws.com/Images/FEMA_CAP/1013/20145/DSC_0042_970b36fb-582b-4b51-a581-923efb394278.jpg,1,0,0,0
https://ladi.s3-us-west-2.amazonaws.com/Images/FEMA_CAP/1013/20147/DSC_1575_787dc8f2-fb8f-4464-99a2-45ba5fc677c2.jpg,1,0,0,0
...,...,...,...,...
https://ladi.s3-us-west-2.amazonaws.com/Images/FEMA_CAP/9168/616298/DSC_3990_bdd98b00-d138-4daa-b842-4a4266cb2de9.jpg,0,0,0,1
https://ladi.s3-us-west-2.amazonaws.com/Images/FEMA_CAP/9168/616298/DSC_3997_d3d2337b-b3e3-468c-8e2c-93ba34f9a219.jpg,0,0,0,1
https://ladi.s3-us-west-2.amazonaws.com/Images/FEMA_CAP/9168/616298/DSC_4000_af3e8e76-fbba-4a6f-8aa7-44cd7c35fdaf.jpg,0,0,0,1
https://ladi.s3-us-west-2.amazonaws.com/Images/FEMA_CAP/9168/616298/DSC_4005_1dd9e503-51f1-4857-ad95-6aec4abfe521.jpg,0,0,0,1


In [9]:
# Load ladi_images_metadata.csv
metadata = pd.read_csv('http://ladi.s3-us-west-2.amazonaws.com/Labels/ladi_images_metadata.csv')

In [10]:
sample_size=500
# none_sample = final_mat[final_mat['damage:none'] == 1].sample(sample_size)
# none_sample.shape

samples = final_mat[final_mat['damage:flood/water'] == 1].sample(sample_size)
for col in final_mat.columns[1:]:
#     print(col)
#     print(final_mat[final_mat[col] == 1].shape)
#     print(final_mat[final_mat[col] == 1].sample(sample_size))
    samples = samples.append(final_mat[final_mat[col] == 1].sample(sample_size))
#     print(samples)
samples = samples.reset_index().drop_duplicates(subset='url', keep='first').set_index('url')
# df = pd.DataFrame()
# for i, s in enumerate(samples):
#     df = df.merge(s, how='outer', left_index=True, right_index=True)
samples.shape

(1972, 4)

In [11]:
samples.head(500)

Answer,damage:flood/water,damage:rubble,damage:smoke/fire,damage:none
url,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
https://ladi.s3-us-west-2.amazonaws.com/Images/FEMA_CAP/9073/613891/DSC_4635_f6181912-660c-45b5-bfb2-3da64322754c.jpg,1,0,0,0
https://ladi.s3-us-west-2.amazonaws.com/Images/FEMA_CAP/3043/291612/DSC_7137_b9ba9c2c-9c60-4c9a-80d3-fd15fafa556b.jpg,1,0,0,0
https://ladi.s3-us-west-2.amazonaws.com/Images/FEMA_CAP/9075/614080/A0001_AP_c9e3d19a-5cc3-463a-b13b-bbbc765af45d.jpg,1,0,0,0
https://ladi.s3-us-west-2.amazonaws.com/Images/FEMA_CAP/2024/100502/010_1453_f96fbcca-8643-4533-a332-b7016f092ad9.jpg,1,0,0,0
https://ladi.s3-us-west-2.amazonaws.com/Images/FEMA_CAP/3045/291518/DSC_1099_57e91316-8e7b-48f1-86a2-09b7e269c521.jpg,1,0,0,0
https://ladi.s3-us-west-2.amazonaws.com/Images/FEMA_CAP/3033/231171/DSC_1032_beae966c-a218-43f2-afc4-ae7c3740ae40.jpg,1,0,0,0
https://ladi.s3-us-west-2.amazonaws.com/Images/FEMA_CAP/9073/613546/DSC_0881_729b0c80-b1a5-4b9e-8788-61f1da4a9e6d.jpg,1,0,0,0
https://ladi.s3-us-west-2.amazonaws.com/Images/FEMA_CAP/9075/614066/A0007_AP_8ec03ded-74eb-476a-95c2-e9a023af4bf6.jpg,1,0,0,0
https://ladi.s3-us-west-2.amazonaws.com/Images/FEMA_CAP/1013/20186/CAP_2123_a8009b2c-d3af-4d04-aeef-17f5c49e2f4e.jpg,1,0,0,0
https://ladi.s3-us-west-2.amazonaws.com/Images/FEMA_CAP/9074/613811/DSC_0730_a3890e34-c437-4fb5-ad1c-24000a521b3c.jpg,1,0,0,0


In [12]:
metadata_csv = 'damage_sample_metadata.csv'
label_csv = 'damage_sample_label.csv'

In [13]:
# Load ladi_images_metadata.csv
metadata = pd.read_csv('http://ladi.s3-us-west-2.amazonaws.com/Labels/ladi_images_metadata.csv')

# sampling
# sample_size=1000
# flood_sample = flood_examples.sample(sample_size)
# non_flood_sample = non_flood_examples.sample(sample_size)

# creating a df with True/False labels for flooding
# training_flood = pd.DataFrame(index=flood_sample.index, data={'label':True}).reset_index()
# training_non_flood = pd.DataFrame(index=non_flood_sample.index, data={'label':False}).reset_index()
# label_df = pd.concat([training_flood, training_non_flood], ignore_index=True)

samples.to_csv(label_csv)

# create list of urls to download
samples.to_csv('urls_to_download.csv', index=True, header=False)

# # Get flood and non-flood metadata
damage_metadata = metadata[metadata['url'].isin(samples.index)]
# flood_metadata = metadata[metadata['url'].isin(flood_sample.index)]
# not_flood_metadata = metadata[metadata['url'].isin(non_flood_sample.index)]
# training_metadata = pd.concat([flood_metadata, not_flood_metadata], ignore_index=True)

damage_metadata.to_csv(metadata_csv)

In [58]:
!mkdir -p training_images
!wget --content-disposition --trust-server-names -i urls_to_download.csv -P training_images/

404 Not Found
2021-07-30 04:08:49 ERROR 404: Not Found.

--2021-07-30 04:08:49--  https://ladi.s3-us-west-2.amazonaws.com/Images/FEMA_CAP/9073/613551/DSC_0615_733648a6-8006-4a4a-ac2a-a0bec8af4a69.jpg,1,0,0,0
Reusing existing connection to ladi.s3-us-west-2.amazonaws.com:443.
HTTP request sent, awaiting response... ^C


In [19]:
scale = transforms.Resize(768)
crop = transforms.RandomCrop(512)
rotate = transforms.RandomRotation(20)
flip_demo = transforms.RandomHorizontalFlip(1) # flip with 100% chance just to demo
flip = transforms.RandomHorizontalFlip(p=0.5)
composed = transforms.Compose([scale,
                               crop,
                               rotate,
                               flip_demo])

In [16]:
# convenient function for showing the images
def show_image(image):
    plt.imshow(image)
    # pause a bit so that plots are updated
    plt.pause(0.01)

def convert_url_to_local_path(url):
    '''
    gets the location of the downloaded image
    '''
    return 'training_images/'+url.split('/')[-1]

class DamageSampleDataset(Dataset):

    def __init__(self, metadata_csv, label_csv, transform = None):
        """
        Args:
            metadata_csv (string): Path to the csv file with metadata.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.damage_sample_metadata = pd.read_csv(metadata_csv)
        # get the path in the shared directory
        self.damage_sample_metadata['local_path'] = self.damage_sample_metadata['url'].apply(convert_url_to_local_path)
        self.damage_sample_label = pd.read_csv(label_csv)
        self.damage_sample_data = pd.merge(self.damage_sample_metadata, 
                                        self.damage_sample_label,
                                       on="url")
        self.transform = transform
        
    def __len__(self):
        return len(self.damage_sample_metadata)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        ## Load images from local directory. There is no need to redownload images to local machine. ##
        local_path = self.damage_sample_metadata.iloc[idx]['local_path']
        url = self.damage_sample_metadata.iloc[idx]['url']
        try:
            image = Image.fromarray(io.imread(local_path))
            img_name = local_path
        except:
            image = Image.fromarray(io.imread(url))
            img_name = url
        uuid = self.damage_sample_data.iloc[idx, 1]
        timestamp = self.damage_sample_data.iloc[idx, 2]
        gps_lat = self.damage_sample_data.iloc[idx, 3]
        gps_lon = self.damage_sample_data.iloc[idx, 4]
        gps_alt = self.damage_sample_data.iloc[idx, 5]
        file_size = self.damage_sample_data.iloc[idx, 6]
        width = self.damage_sample_data.iloc[idx, 7]
        height = self.damage_sample_data.iloc[idx, 8]
        label = self.damage_sample_data.iloc[idx, -1]
        
        if self.transform:
            image = self.transform(image)

        sample = {'image': image, 
                  'image_name': img_name, 
                  'damage:flood/water': label, 
                  'uuid': uuid, 
                  'timestamp': timestamp, 
                  'gps_lat': gps_lat, 
                  'gps_lon': gps_lon, 
                  'gps_alt': gps_alt, 
                  'orig_file_size': file_size, 
                  'orig_width': width, 
                  'orig_height': height}

        return sample

In [17]:
damage_sample_dataset = DamageSampleDataset(metadata_csv = metadata_csv, label_csv = label_csv)

In [20]:
transformed_dataset = DamageSampleDataset(metadata_csv = metadata_csv, 
                                       label_csv = label_csv, 
                                       transform = transforms.Compose([scale, 
                                                                       crop, 
                                                                       rotate, 
                                                                       flip, 
                                                                       transforms.ToTensor()]
                                                                     )
                                      )

In [55]:
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader

batch_size = 16
test_split_ratio = .1
shuffle_dataset = True
random_seed = 42
# num_workers = 1

# Creating data indices for training and validation splits:
dataset_size = len(transformed_dataset)
indices = list(range(dataset_size))
split = int(np.floor(test_split_ratio * dataset_size))
if shuffle_dataset :
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, test_indices = indices[split:], indices[:split]

# Creating data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)

train_loader = torch.utils.data.DataLoader(transformed_dataset, batch_size=batch_size,
                                           sampler=train_sampler)
test_loader = torch.utils.data.DataLoader(transformed_dataset, batch_size=batch_size,
                                                sampler=test_sampler)

In [54]:
len(train_indices)

1775

In [29]:
dataiter = iter(train_loader)
images, labels = dataiter.next()
print(type(images))
print(images.shape)
print(labels.shape)
for i, data in enumerate(train_loader, 0):

ValueError: too many values to unpack (expected 2)

In [23]:
import torch.nn as nn
import torch.nn.functional as F
try:
    from cnn_finetune import make_model
except:
    !pip install cnn-finetune
    from cnn_finetune import make_model

In [30]:
net = make_model('resnet50', num_classes=4, pretrained=True).cuda()

In [31]:
import torch.optim as optim

criterion = nn.BCELoss()
# optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
optimizer = optim.Adam(net.parameters(), lr=0.0005)

In [49]:
t = torch.tensor([[0,0,0,1],[0,0,0,1],[0,0,0,1],[0,0,0,1],[0,0,0,1],[0,0,0,1],[0,0,0,1],[0,0,0,1],[0,0,0,1],[0,0,0,1],[0,0,0,1],[0,0,0,1],[0,0,0,1],[0,0,0,1],[0,0,0,1],[0,0,0,1]])

In [59]:
a=0

def get_checkpoint_path(epoch):
    return f'epoch_checkpoints/flood_checkpoint_epoch{epoch}.pth'

torch.backends.cudnn.benchmark = True # flag for some GPU optimizations
starting_epoch = 1
additional_epochs = 1
if starting_epoch > 1:
    net.load_state_dict(torch.load(get_checkpoint_path(starting_epoch)))
for epoch in range(starting_epoch, starting_epoch+additional_epochs):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        # get the inputs; data is a list of [inputs, labels]
        a=data
        inputs = data['image'].cuda()
        labels = zip(data['damage:flood/water'], data['damage:rubble'], data['damage:smoke/fire'], data['damage:none']).cuda()
#         labels = data[['damage:flood/water','damage:rubble','damage:smoke/fire','damage:none']].cuda()
        labels = t.cuda()
        # casting int to long for loss calculation#
        labels = labels.long()

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if (i+1) % 10 == 0:    # print every 10 mini-batches
            print(f'[epoch {epoch}, batch {i +1} ] average loss: {running_loss/10}')
            running_loss = 0.0
    # save the model
    PATH = get_checkpoint_path(epoch)
    torch.save(net.state_dict(), PATH)
print('Finished Training')

KeyError: 'damage:rubble'

In [57]:
print(a)

{'image': tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.]

In [None]:
def imshow(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

dataiter = iter(test_loader)
single_iter = dataiter.next()
images = single_iter['image']
labels = single_iter['damage:flood/water']

In [None]:
net.load_state_dict(torch.load(PATH))

outputs = net(images.cuda())
_, predicted = torch.max(outputs, 1)

print('Predicted: ', ' '.join('%5s' % predicted[j].cpu()
                              for j in range(batch_size)))

In [None]:
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        images = data['image'].cuda()
        labels = data['damage:flood/water'].cuda()
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the test images: %d %%' % (
    100 * correct / total))

In [None]:
truth_labels = []
predicted_labels = []
with torch.no_grad():
    for data in test_loader:
        images = data['image'].cuda()
        labels = data['damage:flood/water'].cuda()
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        truth_labels.append(labels.cpu())
        predicted_labels.append(predicted.cpu())
truth_labels = np.concatenate([x.numpy() for x in truth_labels])
predicted_labels = np.concatenate([x.numpy() for x in predicted_labels])

In [None]:
import sklearn.metrics
confusion_matrix = sklearn.metrics.confusion_matrix(truth_labels, predicted_labels)
disp = sklearn.metrics.ConfusionMatrixDisplay(confusion_matrix, ['flood','no flood'])
disp.plot()