In [1]:
import os
import shutil
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

In [2]:
# Base directory that every data/model path below is built from (also appended to sys.path later).
BASE_DIR = '..'
# Country/survey identifier; used as a folder name and to filter the 'country' column.
COUNTRY = 'malawi_2016'
# Name of the dataframe column that is binned and used as the classification target.
METRIC = 'FCS'
# Seed used for the numpy and torch RNGs in this notebook.
RANDOM_SEED = 7

# Holds the 'train' and 'valid' folders of symlinked images the CNN reads from.
CNN_TRAIN_IMAGE_DIR = os.path.join(BASE_DIR, 'data', 'cnn_images', COUNTRY, METRIC)
# Where intermediate checkpoints and the final model are written.
CNN_SAVE_DIR = os.path.join(BASE_DIR, 'models', COUNTRY, METRIC)

# Number of quantile groups to cut the METRIC distribution into (= number of output classes)
NUMBER_OF_BINS = 4

# reduce if memory errors on CUDA
BATCH_SIZE = 8

# Number of epochs to train for (epochs are numbered 0 .. TOTAL_EPOCHS - 1).
# When the training loop reaches epoch 5 it re-enables gradients for the entire network
# (not just the newly initialized final layer).
TOTAL_EPOCHS = 30
# First epoch to run. If initialize_model finds saved checkpoints from earlier epochs,
# it loads the latest one and overwrites this variable.
CURRENT_EPOCH = 0

# Per-country raw data lives under COUNTRIES_DIR/<country>/...
COUNTRIES_DIR = os.path.join(BASE_DIR, 'data', 'countries')
# Processed CSVs for COUNTRY (input download list and the cached METRIC split).
PROCESSED_DIR = os.path.join(COUNTRIES_DIR, COUNTRY, 'processed')

In [3]:
import sys
sys.path.append(BASE_DIR)
from utils import merge_on_lat_lon

In [4]:
# Make sure the output locations exist; exist_ok=True keeps this cell safe to re-run.
os.makedirs(CNN_TRAIN_IMAGE_DIR, exist_ok=True)
os.makedirs(CNN_SAVE_DIR, exist_ok=True)

# Preprocess
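The expensive part of this section only has to run once: the resulting table is cached at `PROCESSED_DIR/<METRIC>.csv`, so on later runs (for example, after the script broke during training) these cells simply reload the cached table and you can move straight on to training.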

In [5]:
# Table of images that were requested for download (one row per image).
df_download = pd.read_csv(os.path.join(PROCESSED_DIR, 'image_download_locs.csv'))
# Files that actually exist on disk. Use COUNTRY (not a hard-coded folder name) so this
# stays consistent with PROCESSED_DIR when the configuration cell is changed.
downloaded = os.listdir(os.path.join(COUNTRIES_DIR, COUNTRY, 'images'))

print(f"actually downloaded: {len(downloaded)}, expected: {len(df_download)}")

# Keep only the rows whose image was actually downloaded. An isin() filter (instead of
# dropping by label) does not raise when the folder contains files that are not listed
# in the CSV, e.g. hidden/system files.
df_download = df_download[df_download['image_name'].isin(downloaded)].reset_index(drop=True)

actually downloaded: 19472, expected: 19500


In [6]:
def assign_bin(cutoffs):
    '''
    Returns a function that takes a scalar value x and assigns it to a bin based on
    the cutoffs given to the "parent" function.

    cutoffs: ascending sequence of bin edges (len(cutoffs) - 1 bins).

    Bin i covers the half-open interval [cutoffs[i], cutoffs[i + 1]). The last bin is
    closed on the right as well, so a value equal to the largest cutoff (e.g. the maximum
    of the data the cutoffs were computed from) is assigned to the last bin instead of
    being rejected.

    The returned function raises ValueError for values outside [cutoffs[0], cutoffs[-1]].
    '''
    def binning_function(x):
        # the inner function is still aware of the variable cutoffs (closure)
        num_bins = len(cutoffs) - 1
        for i in range(num_bins):
            # plain comparisons work for Python scalars as well as numpy scalars
            if cutoffs[i] <= x < cutoffs[i + 1]:
                return i
        # the top edge belongs to the last bin
        if num_bins >= 1 and x == cutoffs[-1]:
            return num_bins - 1
        raise ValueError(f'Given value {x} is outside the cutoffs')
    return binning_function

def create_bin(df, metric):
    '''
    df: dataframe with column metric
    metric: name of the column to bin

    Uses a quantile cut to bin the metric of interest into NUMBER_OF_BINS equally-represented categories.
    Also identifies the rows whose value is near the lower or upper cutoff of their bin, where
    "near" means within 10% of the effective span. The effective span is the smaller of the
    bin's own width and the width of the neighbouring bin on that side, which prevents a bin
    with a very large span from dominating.

    Adds columns 'bin', 'near_upper', and 'near_lower' to df (df is modified in place; nothing is returned).
    The lowest bin never gets 'near_lower' and the highest bin never gets 'near_upper',
    because there is no neighbouring bin on that side.
    '''
    # pd.qcut is deterministic; the seed call is kept only so the global numpy RNG state
    # after this function is the same as it has always been.
    np.random.seed(RANDOM_SEED)
    frac_lower = 0.1 # lower 10% of a bin's range will count as being "near"
    frac_upper = 0.1 # upper 10% of a bin's range will count as being "near"
    bins, bin_cutoffs = pd.qcut(df[metric], NUMBER_OF_BINS, retbins=True)
    df['bin'] = bins.cat.codes.astype(np.int64)
    df['near_lower'] = False
    df['near_upper'] = False
    num_bins = len(bin_cutoffs) - 1

    # every bin except the lowest one has a lower neighbour
    for i in range(1, num_bins):
        span = min(bin_cutoffs[i + 1] - bin_cutoffs[i], bin_cutoffs[i] - bin_cutoffs[i - 1])
        lower_c = bin_cutoffs[i] + frac_lower * span
        # single .loc call on the frame: chained `df[col].loc[mask] = ...` writes to a
        # temporary object under pandas Copy-on-Write and would leave df unchanged
        df.loc[(df['bin'] == i) & (df[metric] < lower_c), 'near_lower'] = True

    # every bin except the highest one has an upper neighbour
    for i in range(0, num_bins - 1):
        span = min(bin_cutoffs[i + 1] - bin_cutoffs[i], bin_cutoffs[i + 2] - bin_cutoffs[i + 1])
        upper_c = bin_cutoffs[i + 1] - frac_upper * span
        df.loc[(df['bin'] == i) & (df[metric] > upper_c), 'near_upper'] = True

def symlink_images(df_images):
    '''
    df_images: dataframe with 'image_name', 'country', 'is_train' columns

    Symlinks every image into the "train" or "valid" folder of CNN_TRAIN_IMAGE_DIR, depending
    on its 'is_train' flag. Both folders must already exist.

    A symlink is a small filesystem entry that points at another file, so the images do not
    have to be copied, which saves disk space and time. Opening the symlink opens the real
    file, so the CNN training code can read the linked directory like any other.
    The real files stay in the original download directory at COUNTRIES_DIR/<country>/images.
    THAT ORIGINAL DOWNLOAD DIRECTORY CANNOT BE MOVED OR MODIFIED OR SCRIPTS WILL BREAK

    Raises ValueError(src, dest) if a link cannot be created (e.g. it already exists).
    '''
    splits = (
        ('train', df_images[df_images['is_train']]),
        ('valid', df_images[~df_images['is_train']]),
    )

    # uses symlinking to save disk space
    for split_name, split_df in splits:
        print(f'symlinking {split_name} images')
        for im_name, country in tqdm(zip(split_df['image_name'], split_df['country']), total=len(split_df)):
            # absolute source path so the link resolves regardless of the working directory
            src = os.path.abspath(os.path.join(COUNTRIES_DIR, country, 'images', im_name))
            dest = os.path.join(CNN_TRAIN_IMAGE_DIR, split_name, im_name)
            try:
                # os.symlink avoids spawning a shell per image and is safe for paths that
                # contain spaces or shell metacharacters
                os.symlink(src, dest)
            except OSError as err:
                print("error creating symlink")
                raise ValueError(src, dest) from err
    return

def preprocess_country(df, frac=0.7):
    '''
    df: dataframe with image_name, country, cluster_lat, cluster_lon, METRIC
    frac: fraction of clusters (not of images) to use for training; the rest is validation

    If PROCESSED_DIR/METRIC.csv already exists, it is loaded and returned and nothing else happens.
    Otherwise the rows for COUNTRY are split by cluster, binned with create_bin, the images are
    symlinked into CNN_TRAIN_IMAGE_DIR/train and CNN_TRAIN_IMAGE_DIR/valid, and the dataframe is
    saved to PROCESSED_DIR/METRIC.csv.

    Returns the dataframe with the added 'is_train', 'bin', 'near_lower' and 'near_upper' columns.
    '''
    savepath = os.path.join(PROCESSED_DIR, f'{METRIC}.csv')
    if os.path.exists(savepath):
        print("already processed this country")
        df_images = pd.read_csv(savepath)
        return df_images

    # seed the global numpy RNG so the cluster shuffle below is reproducible
    np.random.seed(RANDOM_SEED)
    df_images = df[df['country'] == COUNTRY].copy()
    unique_clusters = df_images[['cluster_lat', 'cluster_lon']].drop_duplicates()
    shuffled_clusters = unique_clusters.sample(frac=1)
    num_train = int(frac * len(shuffled_clusters))
    # explicit copy: adding a column to a slice triggers SettingWithCopyWarning
    train_clusters = shuffled_clusters[:num_train].copy()
    train_clusters['is_train'] = True
    df_images = merge_on_lat_lon(df_images, train_clusters, how='left')
    # if not marked as true, will be NA (aka a validation cluster).
    # Assign the result back (chained inplace fillna is unreliable) and force a real
    # boolean dtype so that `~df_images['is_train']` behaves as a logical NOT.
    df_images['is_train'] = df_images['is_train'].fillna(False).astype(bool)
    create_bin(df_images, METRIC)

    # exist_ok=False is deliberate: existing folders mean a previous run already (partially)
    # created symlinks, and linking again on top of them would fail image by image.
    os.makedirs(os.path.join(CNN_TRAIN_IMAGE_DIR, 'train'), exist_ok=False)
    os.makedirs(os.path.join(CNN_TRAIN_IMAGE_DIR, 'valid'), exist_ok=False)

    symlink_images(df_images)

    # save to disk (only after symlinking succeeded, so the cache implies the links exist)
    df_images.to_csv(savepath, index=False)
    return df_images

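Split images into train/valid.
The split is made by cluster: 70% of the clusters (the default `frac=0.7`) are used for training and the remaining 30% for validation, so every image of a cluster goes to the same split.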

In [7]:
# Build the train/valid split and bins, or reload them from PROCESSED_DIR/<METRIC>.csv if they were already created.
df_images = preprocess_country(df_download)

already processed this country


# Train Model
Adapted from the PyTorch CNN training tutorial.

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
from PIL import Image

In [9]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# bare expression so the notebook displays which device was selected
DEVICE

device(type='cuda', index=0)

In [10]:
def initialize_model():
    '''
    Returns (model, input_size), resuming from the latest checkpoint when one exists.

    Checkpoints are files in CNN_SAVE_DIR named exactly 'trained_model_<METRIC>_epoch_<N>.pt'.
    Files that do not match this pattern (other extensions, non-numeric epoch, the final
    'trained_model_<METRIC>.pt') are ignored.

    - If at least one checkpoint exists, the one with the largest N is loaded onto DEVICE and the
      global CURRENT_EPOCH is set to N + 1.
    - Otherwise a pretrained VGG11-BN is created, all of its existing parameters are frozen, and its
      last classifier layer is replaced by a fresh Linear layer with NUMBER_OF_BINS outputs.
      CURRENT_EPOCH is left untouched.
    '''
    global CURRENT_EPOCH
    input_size = 224 # hardcoded for VGG, our network
    prefix = f'trained_model_{METRIC}_epoch_'
    suffix = '.pt'

    # collect (epoch, filename) for every well-formed checkpoint name
    checkpoints = []
    for fname in os.listdir(CNN_SAVE_DIR):
        if not (fname.startswith(prefix) and fname.endswith(suffix)):
            continue
        epoch_str = fname[len(prefix):-len(suffix)]
        # skip names whose epoch part is not a plain number instead of crashing in int()
        if epoch_str.isdecimal():
            checkpoints.append((int(epoch_str), fname))

    if checkpoints:
        largest_epoch, checkpoint_name = max(checkpoints)
        print(f'using existing model at epoch {largest_epoch}')
        CURRENT_EPOCH = largest_epoch + 1
        path = os.path.join(CNN_SAVE_DIR, checkpoint_name)
        # NOTE(review): this unpickles a whole model object, which executes arbitrary code from the
        # file; only load checkpoints you created yourself. Newer PyTorch releases may also require
        # passing weights_only=False here - confirm against the installed version.
        model = torch.load(path, map_location=DEVICE)
    else:
        print('initializing new model')
        torch.manual_seed(RANDOM_SEED)
        model = models.vgg11_bn(pretrained=True)
        # turn off training for all existing parameters (for now)
        for param in model.parameters():
            param.requires_grad = False
        # the replacement layer is created after the freeze, so it keeps requires_grad=True
        num_ftrs = model.classifier[6].in_features
        model.classifier[6] = nn.Linear(num_ftrs, NUMBER_OF_BINS)
        model = model.to(DEVICE)
    return model, input_size

# Load the latest checkpoint or build a fresh network (this may also update CURRENT_EPOCH).
model, input_size = initialize_model()
# All parameters are registered with the optimizer up front, so nothing has to be re-created
# when train_model switches requires_grad back on for the whole network at epoch 5.
# A new optimizer is built on every run: the checkpoints only contain `model`, not optimizer state.
optimizer = optim.Adam(model.parameters(), lr=3e-6)

using existing model at epoch 14


In [11]:
# Indexed by image file name; the dataset queries it for the label ('bin') and the
# loss queries it for the 'near_lower' / 'near_upper' flags.
DF_LOOKUP = df_images.set_index('image_name')

In [12]:
# Preprocessing pipeline for each phase, keyed by the folder name it is applied to.
# The normalization constants are presumably the ImageNet channel statistics expected by
# the pretrained torchvision weights - TODO confirm.
data_transforms = {}

# training: random crop and random horizontal flip as augmentation
data_transforms['train'] = transforms.Compose([
    transforms.RandomResizedCrop(input_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# validation: deterministic resize followed by a center crop, no augmentation
data_transforms['valid'] = transforms.Compose([
    transforms.Resize(input_size),
    transforms.CenterCrop(input_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

print("Initializing Datasets and Dataloaders...")

class ForwardPassDataset(torch.utils.data.Dataset):
    '''
    Dataset over every file in image_dir.

    Each item is (X, y, image_name): the transformed image, its 'bin' label looked up in
    DF_LOOKUP by file name, and the file name itself (the loss uses the name to look up the
    near_lower / near_upper flags).
    '''
    def __init__(self, image_dir, transformer):
        # image_dir: folder of images (symlinks are fine)
        # transformer: callable applied to the PIL image, e.g. a torchvision Compose
        self.image_dir = image_dir
        self.image_list = os.listdir(self.image_dir)
        self.transformer = transformer

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index):
        image_name = self.image_list[index]

        # Load image
        X = self.filename_to_im_tensor(os.path.join(self.image_dir, image_name))
        y = DF_LOOKUP.loc[image_name]['bin']

        return X, y, image_name

    def filename_to_im_tensor(self, file):
        # Decode straight to 8-bit RGB with PIL. The previous float round-trip
        # (imread(...) * 256 cast to uint8) wrapped fully saturated channels around to 0
        # and failed on images that are already uint8 or have no channel axis.
        # convert('RGB') also drops an alpha channel, like the old [:, :, :3] slice did.
        with Image.open(file) as raw:
            im = raw.convert('RGB')
        im = self.transformer(im)
        return im

# One dataset per phase, reading from the matching sub-folder of CNN_TRAIN_IMAGE_DIR
# and using that phase's transform pipeline.
image_datasets = {
    phase: ForwardPassDataset(os.path.join(CNN_TRAIN_IMAGE_DIR, phase), data_transforms[phase])
    for phase in ('train', 'valid')
}
# One dataloader per dataset; both phases are shuffled and loaded with 4 worker processes.
dataloaders_dict = {
    phase: torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
    )
    for phase, dataset in image_datasets.items()
}

Initializing Datasets and Dataloaders...


In [13]:
class CustomCriterion:
    '''
    Cross-entropy loss that is softened for images lying close to a bin border.

    For an image flagged near_upper (or near_lower) in DF_LOOKUP, the loss is a weighted mix
    of the cross entropy against its own bin and against the bin above (or below).
    All other images use plain cross entropy. near_upper takes precedence when both are set.
    '''
    def __init__(self, alpha=0.75):
        # alpha: weight given to the correct class for images flagged
        # near_lower or near_upper; the neighbouring bin gets 1 - alpha
        self.alpha = alpha
        self.criterion = nn.CrossEntropyLoss()

    def __call__(self, outputs, labels, image_names):
        # outputs: per-sample class scores, labels: per-sample bin index,
        # image_names: file names used to look up the border flags
        total = None
        for idx, name in enumerate(image_names):
            row = DF_LOOKUP.loc[name]
            logits = outputs[idx].reshape(1, -1)
            target = labels[idx].reshape(1)

            # pick the neighbouring bin (if any) that shares part of the loss
            if row['near_upper']:
                neighbour = target + 1  # shift the criterion to the upper bin
            elif row['near_lower']:
                neighbour = target - 1  # shift the criterion to the lower bin
            else:
                neighbour = None

            if neighbour is None:
                sample_loss = self.criterion(logits, target)  # regular cross entropy
            else:
                sample_loss = self.alpha * self.criterion(logits, target) + \
                    (1 - self.alpha) * self.criterion(logits, neighbour)

            total = sample_loss if total is None else total + sample_loss
        return total / len(image_names)  # averaged over the batch

In [14]:
def train_model(model, dataloaders, criterion, optimizer, num_epochs):
    '''
    Runs the train/validation loop from the global CURRENT_EPOCH up to num_epochs - 1.

    model: network to train (updated in place)
    dataloaders: mapping with 'train' and 'valid' entries; each yields (inputs, labels, image_names)
    criterion: callable invoked as criterion(outputs, labels, image_names), returning the batch loss
    optimizer: optimizer stepped once per training batch
    num_epochs: total number of epochs; epochs before CURRENT_EPOCH are skipped

    Side effects: increments the global CURRENT_EPOCH after every completed epoch and saves the
    whole model to CNN_SAVE_DIR after epochs 4, 9, 14, ...

    Returns (model, val_acc_history): the model with the weights of the best validation accuracy
    seen during THIS call loaded, and the list of validation accuracies, one per epoch run.
    '''
    # CURRENT_EPOCH is reassigned below, hence the global declaration; DEVICE is only read
    global CURRENT_EPOCH, DEVICE
    since = time.time()
    val_acc_history = []
    # fall back to the incoming weights if no epoch beats an accuracy of 0
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(CURRENT_EPOCH, num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        # NOTE: this only fires when the loop actually passes through epoch 5
        if epoch == 5:
            # fine tune whole model now
            for param in model.parameters():
                param.requires_grad = True

        # every epoch runs a full training pass followed by a full validation pass
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels, image_names in tqdm(dataloaders[phase]):
                inputs = inputs.to(DEVICE)
                labels = labels.to(DEVICE)

                # zero the parameter gradients
                optimizer.zero_grad()
                # track gradients in train phase only
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels, image_names)
                    # predicted class = index of the largest output
                    _, preds = torch.max(outputs, 1)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                # the loss is weighted by batch size so the epoch loss below is a per-sample
                # average (assumes criterion returns a batch mean, as CustomCriterion does)
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            # deep copy the model if it is better
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'valid':
                val_acc_history.append(epoch_acc)

        if epoch % 5 == 4:
            # save intermediate results in case script breaks
            # (the file name pattern is the one initialize_model searches for when resuming)
            savepath = os.path.join(CNN_SAVE_DIR, f'trained_model_{METRIC}_epoch_{epoch}.pt')
            torch.save(model, savepath)

        # end one epoch
        CURRENT_EPOCH += 1
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model
    model.load_state_dict(best_model_wts)
    return model, val_acc_history

In [15]:
# Border-aware loss with its default alpha of 0.75.
criterion = CustomCriterion()
# Train from CURRENT_EPOCH up to TOTAL_EPOCHS; `hist` holds the validation accuracy of each epoch run.
model, hist = train_model(model, dataloaders_dict, criterion, optimizer, TOTAL_EPOCHS)

Epoch 15/29
----------


HBox(children=(FloatProgress(value=0.0, max=1704.0), HTML(value='')))


train Loss: 1.1730 Acc: 0.4581


HBox(children=(FloatProgress(value=0.0, max=731.0), HTML(value='')))


valid Loss: 1.2817 Acc: 0.3821

Epoch 16/29
----------


HBox(children=(FloatProgress(value=0.0, max=1704.0), HTML(value='')))


train Loss: 1.1696 Acc: 0.4600


HBox(children=(FloatProgress(value=0.0, max=731.0), HTML(value='')))


valid Loss: 1.2984 Acc: 0.3692

Epoch 17/29
----------


HBox(children=(FloatProgress(value=0.0, max=1704.0), HTML(value='')))


train Loss: 1.1594 Acc: 0.4682


HBox(children=(FloatProgress(value=0.0, max=731.0), HTML(value='')))




KeyboardInterrupt: 

In [None]:
# Persist the final model once; an existing file at this location is never overwritten.
savepath = os.path.join(CNN_SAVE_DIR, f'trained_model_{METRIC}.pt')
if not os.path.isfile(savepath):
    print(f'Saving model to {savepath}')
    torch.save(model, savepath)
else:
    print('A model is already saved at this location')