In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

import os
os.environ["CUDA_VISIBLE_DEVICES"]= "1"
from utils import *
from models import *
from data import *
from train_helper import *

# Project root is the notebook's working directory at start-up; datasets live under <root>/data.
PATH = os.getcwd()
DATA_PATH = "%s/data" % PATH

# Let cuDNN auto-tune convolution algorithms (pays off when input shapes do not vary between batches).
torch.backends.cudnn.benchmark = True

# Display the installed PyTorch version as this cell's output.
torch.__version__

'0.4.0'

In [2]:
# Test image ids are the distinct ImageId values of the sample submission.
test_df = pd.read_csv("{}/sample_submission.csv".format(DATA_PATH))
test_ids = np.unique(test_df["ImageId"])
print("{} Test images".format(len(test_ids)))

# Every image id that has a three-band file: the file name without directory or extension.
# os.path is used instead of split("/") so this also works with Windows path separators.
all_ids = [os.path.splitext(os.path.basename(path))[0]
           for path in glob.glob("{}/three_band/*".format(DATA_PATH))]
print("{} Images in three_band".format(len(all_ids)))

# Labelled images are the distinct ImageId values of the training WKT file
# (read from DATA_PATH like the other data files, not relative to the cwd).
DF = pd.read_csv("{}/train_wkt_v4.csv".format(DATA_PATH))
train_ids = list(np.unique(DF["ImageId"]))

# Hold out a fixed set of labelled images for validation; fail loudly if one is not labelled.
val_ids = ["6100_2_2", "6110_1_2", "6140_3_1", "6160_2_1", "6170_0_4"]
missing_val_ids = set(val_ids) - set(train_ids)
assert not missing_val_ids, "validation ids missing from train_wkt_v4.csv: {}".format(sorted(missing_val_ids))
train_ids = [image_id for image_id in train_ids if image_id not in val_ids]

print("{} Train images".format(len(train_ids)))
print("{} Validation images".format(len(val_ids)))

429 Test images
20 Train images
5 Validation images


In [3]:
# Load the pre-built 12-band image array and its matching mask array
# (".bc" stores read via the project's load_array helper; presumably bcolz — confirm).
imgs, masks = load_array("imgs_12_band.bc"), load_array("masks_12_band.bc")

In [4]:
def main(arg_name, hps):
    """Train a segmentation model on the training ids, validating after every epoch.

    Builds train/validation loaders from the notebook-level ``imgs``, ``masks``,
    ``train_ids`` and ``val_ids``, then trains for up to ``hps.n_epochs`` epochs.
    After each epoch the validation Jaccard drives a ReduceLROnPlateau schedule
    (LR halved after ``hps.patience`` epochs without improvement), a checkpoint is
    saved under ``weights/<arg_name>``, and training stops early once the Jaccard
    has not improved for ``3 * hps.patience`` consecutive epochs.

    Args:
        arg_name: run name; used as the Logger env, the TensorBoard comment suffix
            and the checkpoint sub-directory.
        hps: hyper-parameter object (a ``HyperParams`` instance in this notebook);
            fields read here include classes, oversample, samples_per_epoch,
            crop_size, batch_size, num_workers, patience, n_epochs and net.
    """
    model = Model(hps)
    model.init_optimizer()

    trn_dataset = DatasetDSTL(
        train_ids, imgs=imgs, masks=masks, classes=hps.classes, oversample=hps.oversample,
        pick_random_idx=True, samples_per_epoch=hps.samples_per_epoch, which_dataset="train",
        transform=transforms.Compose([RandomNumpyCrop(hps.crop_size), OwnToNormalizedTensor()]))
    train_loader = torch.utils.data.DataLoader(
        trn_dataset, batch_size=hps.batch_size, shuffle=True, num_workers=hps.num_workers,
        pin_memory=True, sampler=None)

    # NOTE(review): samples_per_epoch=5 appears to mirror len(val_ids) (one resized image
    # per validation id) — confirm against DatasetDSTL before changing val_ids.
    val_dataset = DatasetDSTL(
        val_ids, imgs=imgs, masks=masks, classes=hps.classes, pick_random_idx=False,
        samples_per_epoch=5, which_dataset="val",
        transform=transforms.Compose([NumpyResize((3200, 3200)), OwnToNormalizedTensor()]))
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=1, shuffle=False, num_workers=hps.num_workers,
        pin_memory=True, sampler=None)

    logger = Logger(env=arg_name)

    # Stop after this many consecutive epochs without a validation-Jaccard improvement.
    early_stopping_patience = hps.patience * 3
    best_jaccard, n_iter, epochs_without_improvement, total_time = 0, 0, 0, 0
    lrs = []
    # PATH is os.getcwd() (cell 1), so this is the same directory save_checkpoint writes to below.
    create_dir("{}/weights/{}".format(PATH, arg_name))
    writer = SummaryWriter(comment="_{}".format(arg_name))
    try:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            model.optimizer, factor=0.5, verbose=True, patience=hps.patience, mode="max")
        writer.add_scalar("train/lr", model.optimizer.param_groups[0]["lr"], n_iter)
        for epoch in range(hps.n_epochs):
            # Train for one epoch; train_epoch returns the updated iteration counter and elapsed time.
            n_iter, total_time = model.train_epoch(train_loader, epoch, n_iter, total_time, logger, writer)
            validate_time = time.time()

            # Evaluate on the validation set; the scheduler maximises Jaccard (mode="max").
            current_jaccard = model.validate(val_loader, epoch, logger, writer)
            scheduler.step(current_jaccard)
            lrs.append(model.optimizer.param_groups[0]["lr"])
            writer.add_scalar("train/lr", model.optimizer.param_groups[0]["lr"], n_iter)

            # Track the best Jaccard and how long it has been since it last improved.
            is_best = current_jaccard > best_jaccard
            if is_best:
                best_jaccard = current_jaccard
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1

            save_checkpoint({
                "lrs": lrs,
                'epoch': epoch + 1,
                'arch': hps.net,
                'state_dict': model.net.state_dict(),
                'best_jaccard': best_jaccard,
                'optimizer': model.optimizer.state_dict(),
            }, is_best, path="weights/{}".format(arg_name))
            total_time += time.time() - validate_time

            # Checked after the counter update (the old pre-increment ">" check ran two extra
            # epochs) and after saving, so the final epoch's checkpoint is not lost.
            if not is_best and epochs_without_improvement >= early_stopping_patience:
                print("Early Stopping")
                break
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()

In [5]:
# Human-readable class names, indexed by class id.
class_list = ["Buildings", "Misc. Manmade structures", "Road", "Track", "Trees", "Crops", "Waterway",
              "Standing Water", "Vehicle Large", "Vehicle Small"]
# Class ids to train on; to train on all ten classes at once use classes = list(range(10)).
classes = [0]
class_ids = [str(j) for j in classes]
classes_string = "_".join(class_list[j] for j in classes)
hps = HyperParams()
hps.update("net=UNet_BN,bn=1,classes={}".format("-".join(class_ids)))

pprint(attr.asdict(hps))

# Derive the run name from the actual hyper-parameters so logs/weights are never filed under
# stale values (the previous hard-coded name claimed lr0.1 while hps.lr was 0.01).
augm_tag = "augm" if (hps.augment_flips or hps.augment_rotations) else "noaugm"
if hps.samples_per_epoch % 1000 == 0:
    samples_tag = "{}k".format(hps.samples_per_epoch // 1000)
else:
    samples_tag = str(hps.samples_per_epoch)
run_name = "{net}_classes_{cls}_crop{crop}_lr{lr}_{augm}_logloss{log_w}_dice_{dice_w}_bs{bs}_sample{samples}".format(
    net=hps.net, cls="_".join(class_ids), crop=hps.crop_size, lr=hps.lr, augm=augm_tag,
    log_w=hps.log_loss_weight, dice_w=hps.dice_loss_weight, bs=hps.batch_size, samples=samples_tag)
main(run_name, hps)

{'augment_flips': 0,
 'augment_rotations': 0.0,
 'batch_size': 256,
 'bn': 1,
 'classes': [0],
 'crop_size': 80,
 'dice_loss_weight': 0.9,
 'filters_base': 32,
 'log_loss_weight': 0.1,
 'lr': 0.01,
 'n_channels': 12,
 'n_epochs': 1000,
 'net': 'UNet_BN',
 'num_gpu': 1,
 'num_workers': 4,
 'opt': 'sgd',
 'oversample': 0.0,
 'patience': 2,
 'print_freq': 100,
 'samples_per_epoch': 20000,
 'weight_decay': 0.0}


  if uwi[0] != 0:


Epoch: [0] TotalTime: 0.5 mins,  BatchTime: 0.383,  DataTime: 0.017,  Loss: 0.8159,  Jaccard: 0.3090
 * Loss 0.790 Jaccard 0.229 (Validation)
Epoch: [1] TotalTime: 1.1 mins,  BatchTime: 0.320,  DataTime: 0.014,  Loss: 0.6247,  Jaccard: 0.4977
 * Loss 0.825 Jaccard 0.206 (Validation)
Epoch: [2] TotalTime: 1.7 mins,  BatchTime: 0.321,  DataTime: 0.016,  Loss: 0.6050,  Jaccard: 0.5291
 * Loss 0.725 Jaccard 0.549 (Validation)
Epoch: [3] TotalTime: 2.2 mins,  BatchTime: 0.321,  DataTime: 0.015,  Loss: 0.5880,  Jaccard: 0.5674
 * Loss 0.741 Jaccard 0.476 (Validation)
Epoch: [4] TotalTime: 2.8 mins,  BatchTime: 0.329,  DataTime: 0.022,  Loss: 0.5784,  Jaccard: 0.5910
 * Loss 0.755 Jaccard 0.440 (Validation)
Epoch: [5] TotalTime: 3.4 mins,  BatchTime: 0.329,  DataTime: 0.023,  Loss: 0.5814,  Jaccard: 0.5821
 * Loss 0.722 Jaccard 0.561 (Validation)
Epoch: [6] TotalTime: 3.9 mins,  BatchTime: 0.324,  DataTime: 0.017,  Loss: 0.5695,  Jaccard: 0.6130
 * Loss 0.728 Jaccard 0.548 (Validation)
Epoch: