In [None]:
import sys
import os
import torch
from torch import optim
from time import time

In [None]:
# Make the tile2vec sources importable. The parent directory is added first,
# then the tile2vec checkout itself; extend() preserves that order.
tile2vec_dir = '/atlas/u/swang/software/GitHub/tile2vec'
sys.path.extend(['../', tile2vec_dir])

In [None]:
from src.datasets import TileTripletsDataset, GetBands, RandomFlipAndRotate, ClipAndScale, ToFloatTensor, triplet_dataloader
from src.tilenet import make_tilenet

In [None]:
from src.training import prep_triplets, train_triplet_epoch

# Step 1. Download triplets from bucket

Using the download link, download the triplet archive and unzip it into the directory `tile2vec/data/triplets/`. The `tile_dir` setting in Step 2 must point to this directory.

# Step 2. Set up dataloader

In [None]:
# Environment stuff: expose only GPU 0 to this process. The variable has to be
# in place before CUDA is first initialized for it to take effect.
os.environ.update({'CUDA_VISIBLE_DEVICES': '0'})
cuda = torch.cuda.is_available()  # later cells use this flag to decide on .cuda()

Set up the dataloader for training.

In [None]:
# Change these arguments to match your directory and desired parameters.

# Data source
img_type = 'naip'
tile_dir = '/atlas/u/swang/GitHub/tile2vec/data/triplets/'
bands = 4                # also used as TileNet's in_channels in Step 3
n_triplets = 10 ** 5     # number of triplets requested from triplet_dataloader

# Loader behaviour (all forwarded to triplet_dataloader)
augment = True
shuffle = True
batch_size = 50
num_workers = 4

In [None]:
# Build the triplet dataloader from the configuration cell above.
# pairs_only=True is forwarded as-is; its exact effect is defined in src.datasets.
loader_kwargs = dict(
    bands=bands,
    augment=augment,
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=num_workers,
    n_triplets=n_triplets,
    pairs_only=True,
)
dataloader = triplet_dataloader(img_type, tile_dir, **loader_kwargs)
print('Dataloader set up complete.')

# Step 3. Set up TileNet

In [None]:
# TileNet takes one input channel per image band; z_dim is forwarded to
# make_tilenet (presumably the embedding size — confirm in src.tilenet).
z_dim = 512
in_channels = bands

In [None]:
# Build the network, switch it to training mode, and move it onto the GPU
# only when one was detected in the environment cell.
TileNet = make_tilenet(in_channels=in_channels, z_dim=z_dim)
TileNet.train()
if cuda:
    TileNet.cuda()
print('TileNet set up complete.')

Set up the optimizer.

In [None]:
# Adam with beta1 = 0.5 instead of PyTorch's default of 0.9.
lr = 1e-3
adam_betas = (0.5, 0.999)
optimizer = optim.Adam(TileNet.parameters(), lr=lr, betas=adam_betas)

# Step 4. Train the model

In [None]:
# Training hyperparameters (margin, l2 and print_every are forwarded to
# train_triplet_epoch; their precise semantics live in src.training).
epochs = 50            # iterations of the training loop below
margin = 10
l2 = 0.01
print_every = 10000
save_models = False    # when True, a checkpoint is written after the last epoch

Define the directory for saving trained models; it is created if it does not already exist.

In [None]:
# Directory where trained model checkpoints are written.
# NOTE(review): this path uses 'tile2net' while tile_dir uses 'tile2vec' —
# confirm this is intentional and not a typo.
model_dir = '/atlas/u/swang/GitHub/tile2net/models/'
# exist_ok=True replaces the racy "check exists, then create" pattern and makes
# the cell safe to re-run.
os.makedirs(model_dir, exist_ok=True)

In [None]:
# Train for `epochs` epochs and record each epoch's average losses in a
# tab-separated log next to the checkpoints.
# results_fn must be defined here: no earlier cell sets it, and opening an
# undefined name raised NameError on a fresh kernel.
results_fn = os.path.join(model_dir, 'TileNet_results.tsv')
t0 = time()
with open(results_fn, 'w') as results_file:
    results_file.write('epoch\tavg_loss\tavg_l_n\tavg_l_d\tavg_l_nd\n')
    print('Begin training.................')
    # Epochs are numbered from 1, matching the epoch index previously passed as epoch+1.
    for epoch in range(1, epochs + 1):
        (avg_loss, avg_l_n, avg_l_d, avg_l_nd) = train_triplet_epoch(
            TileNet, cuda, dataloader, optimizer, epoch, margin=margin, l2=l2,
            print_every=print_every, t0=t0)
        results_file.write('{}\t{}\t{}\t{}\t{}\n'.format(
            epoch, avg_loss, avg_l_n, avg_l_d, avg_l_nd))
        # Flush so the log stays current even if training is interrupted.
        results_file.flush()

In [None]:
# Save model after last epoch (only when save_models is enabled in Step 4).
if save_models:
    # Derive the epoch in the filename from `epochs` rather than hard-coding 50,
    # so changing the epoch count does not produce a mislabelled checkpoint.
    model_fn = os.path.join(model_dir, 'TileNet_epoch{}.ckpt'.format(epochs))
    torch.save(TileNet.state_dict(), model_fn)