In [15]:
import os
import sys
from time import time

import numpy as np
import torch
from torch import optim

In [2]:
# Make the project's `src` package importable, both through the directory
# above the current working directory and through an absolute checkout path.
# Entries already on sys.path are skipped so that re-running this cell does
# not keep appending duplicates.
for repo_path in ('../', '/atlas/u/swang/software/GitHub/tile2vec'):
    if repo_path not in sys.path:
        sys.path.append(repo_path)

In [12]:
from src.datasets import TileTripletsDataset, GetBands, RandomFlipAndRotate, ClipAndScale, ToFloatTensor, triplet_dataloader
from src.tilenet import make_tilenet

# Step 1. Download triplets from bucket

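Download the triplets archive using the download link, then unzip it into the `data/triplets/` directory of your tile2vec checkout (the directory that `tile_dir` points to in Step 2).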

# Step 2. Set up dataloader

In [4]:
# Environment stuff
# Expose only device 0 to CUDA, then remember whether CUDA is usable so the
# later cells can decide whether to move the model to the GPU.
os.environ.update({'CUDA_VISIBLE_DEVICES': '0'})
cuda = torch.cuda.is_available()

Set up the dataloader for training.

In [5]:
# Change these arguments to match your directory and desired parameters.
# Every value below is forwarded to triplet_dataloader in the next cell;
# `bands` is additionally reused as the network's input channel count.

# Location and description of the triplet tiles
tile_dir = '/atlas/u/swang/GitHub/tile2vec/data/triplets/'
img_type = 'naip'
bands = 4

# Dataset size and augmentation switch
n_triplets = 100000
augment = True

# Batching options
batch_size = 50
shuffle = True
num_workers = 4

In [6]:
# Build the training dataloader from the configuration cell above.
# pairs_only is fixed to True here rather than exposed as a setting.
dataloader = triplet_dataloader(
    img_type,
    tile_dir,
    bands=bands,
    augment=augment,
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=num_workers,
    n_triplets=n_triplets,
    pairs_only=True,
)
print('Dataloader set up complete.')

Dataloader set up complete.


# Step 3. Set up TileNet

In [10]:
# Arguments for make_tilenet in the next cell.
z_dim = 512          # passed as make_tilenet's z_dim
in_channels = bands  # one input channel per image band loaded above

In [13]:
# Instantiate the network, switch it to training mode, and move it to the
# GPU when CUDA was detected earlier.
TileNet = make_tilenet(in_channels=in_channels, z_dim=z_dim)
TileNet.train()
if cuda:
    TileNet.cuda()
print('TileNet set up complete.')

TileNet set up complete.


Set up optimizer.

In [17]:
# Adam over all TileNet parameters; beta1 is set to 0.5 instead of Adam's
# default, beta2 is left at 0.999.
lr = 1e-3
optimizer = optim.Adam(
    TileNet.parameters(),
    lr=lr,
    betas=(0.5, 0.999),
)

Define the directory for saving models.

In [18]:
# Directory that receives results.txt and the model checkpoints. It sits in
# the same tile2vec checkout as tile_dir so that data and models stay
# together; adjust it to match your own directory layout.
model_dir = '/atlas/u/swang/GitHub/tile2vec/models/'
# Create the directory (including missing parents) on first use.
if not os.path.exists(model_dir):
    os.makedirs(model_dir)

# Step 4. Train model!

In [None]:
# Number of passes of the training loop below (one train_triplet_epoch call
# plus one evaluation/plotting round per pass).
epochs = 50

In [None]:
# Train TileNet for `epochs` epochs. After every epoch this cell
#   1. plots the loss components,
#   2. embeds the evaluation tiles with the current model and scores a
#      random forest on the embeddings (next to a PCA baseline),
#   3. appends one line to <model_dir>/results.txt,
#   4. refreshes the diagnostic plots (embedding norm, 2-D PCA scatter,
#      latent histograms, nearest-neighbour image grids),
#   5. optionally checkpoints the model.
#
# NOTE(review): `time` and `np` come from the import cell. The following
# names are used here but are not defined in any visible cell and must be
# provided before this cell can run on a fresh kernel:
#   train_triplet_epoch, embed_dataset, get_k_neighbors,
#   train_test_split, RandomForestClassifier, PCA,
#   vis (presumably a visdom client -- confirm), vis_data, rf_accs,
#   margin, l2, print_every, loss_type, rf_triplets, nan_locs,
#   X_pca, y, tiles, tile_size, save_models.
t0 = time()
results_fn = os.path.join(model_dir, 'results.txt')
# Opened with 'w': an existing results.txt is truncated at the start of a run.
with open(results_fn, 'w') as results_file:

    print('Begin training.................')
    for epoch in range(0, epochs):
        # Epoch numbers handed to the trainer and shown in titles are 1-based.
        (avg_loss, avg_l_n, avg_l_d, avg_l_nd) = train_triplet_epoch(
            TileNet, cuda, dataloader, optimizer, epoch+1, margin=margin, l2=l2,
            print_every=print_every, t0=t0)

        # Plot l_n, l_d, and l_nd if available
        # NOTE(review): nothing in this cell appends avg_l_n / avg_l_d /
        # avg_l_nd to vis_data -- confirm that happens elsewhere, otherwise Y
        # will not have one row per epoch. Y and legend are also only set
        # for loss_type 'triplet' or 'cosine'.
        if loss_type == 'triplet':
            Y = np.vstack((np.array(vis_data['avg_l_ns']),
                           np.array(vis_data['avg_l_ds']),
                           np.array(vis_data['avg_l_nds']))).T
            legend = ['l_n', 'l_d', 'l_nd']
        elif loss_type == 'cosine':
            Y = np.vstack((np.array(vis_data['avg_l_ns']),
                           np.array(vis_data['avg_l_ds']))).T
            legend = ['l_n', 'l_d']
        vis.line(Y=Y, X=np.array(range(1, epoch+2)), win='losses',
            opts={'legend': legend, 'markers': False,
            'title': 'Loss components', 'xlabel': 'Epoch',
            'ylabel': 'Average loss'})

        # RF comparison with PCA
        # Embed with the model that is actually being trained (TileNet).
        X = embed_dataset(TileNet, z_dim, cuda, img_type, tile_dir,
            bands=bands, augment=False, batch_size=50,
            shuffle=False, num_workers=4, print_every=None,
            n_triplets=rf_triplets)
        # The mean norm is taken over all rows, before the nan_locs rows
        # are dropped.
        vis_data['avg_z_norms'].append(np.linalg.norm(X, axis=1).mean())
        X = np.delete(X, nan_locs, axis=0)
        X_tr, X_te, X_tr_pca, X_te_pca, y_tr, y_te = train_test_split(X, X_pca, y, shuffle=True)
        rf = RandomForestClassifier()
        rf.fit(X_tr, y_tr)
        rf_acc = rf.score(X_te, y_te)
        rf_accs['tile2vec'].append(rf_acc)
        # The PCA baseline is not re-fitted; its last value is carried
        # forward so both curves have the same length (rf_accs['PCA'] must
        # therefore be non-empty before the first epoch).
        rf_accs['PCA'].append(rf_accs['PCA'][-1])
        Y = np.vstack((np.array(rf_accs['tile2vec']), rf_accs['PCA'])).T
        # X starts at 0: the plot expects epoch+2 points, i.e. one entry
        # recorded before training plus one per finished epoch.
        vis.line(Y=Y, X=np.array(range(0, epoch+2)), win='rf',
            opts={'legend': ['tile2vec', 'PCA'], 'markers': False,
            'title': 'RF comparison on CDL', 'xlabel': 'Epoch',
            'ylabel': 'Accuracy'})

        print('Writing results for epoch {}'.format(epoch+1))
        results_file.write('{} {} {} {} {}\n'.format(
            avg_loss, avg_l_n, avg_l_d, avg_l_nd, rf_acc))
        # Flush so the per-epoch log survives an interrupted run.
        results_file.flush()

        # Plot average norm for embedding
        Y = np.array(vis_data['avg_z_norms'])
        vis.line(Y=Y, X=np.array(range(0, epoch+2)), win='z_norm',
            opts={'title': 'Average embedding norm', 'xlabel': 'Epoch',
            'ylabel': 'Norm'})

        # Plotting PCA of embeddings by CDL class
        # Only the first 1000 points are drawn; the window name cycles
        # through three windows (epoch % 3).
        X_embed_pca = PCA(2).fit_transform(X)
        vis.scatter(X=X_embed_pca[:1000], Y=y[:1000], win='pca{}'.format(epoch % 3),
            opts={'title': 'Epoch {}: PCA by CDL label'.format(epoch+1),
            'xlabel': 'PC 1', 'ylabel': 'PC 2'})

        # Plot histograms of latent dimensions
        # At most the first three latent dimensions are shown.
        for d_z in range(min(z_dim, 3)):
            vis.histogram(X=X[:,d_z], win='z{}'.format(d_z),
                opts={'numbins': 20, 'title': 'Epoch {}: Distribution of z{}'.format(epoch+1, d_z)})

        # Plot nearest neighbors for embedding and for PCA
        n_samples = 5
        nrow = 8
        k = int(nrow ** 2)
        for i in range(n_samples):
            # One nrow x nrow grid of 3-channel tiles per random query tile.
            samples = np.zeros((k, 3, tile_size, tile_size))
            idx = np.random.randint(0, len(tiles))
            # Get embedding and PCA neighbors
            (topk_idxs, topk_dists) = get_k_neighbors(idx, X, int(k/2))
            (topk_idxs_pca, topk_dists_pca) = get_k_neighbors(idx, X_pca, int(k/2))
            X_idx = 0
            X_pca_idx = 0
            for j in range(k):
                # Left half of each grid row takes the next embedding
                # neighbour, right half the next PCA neighbour.
                if j % nrow < nrow / 2:
                    topk_idx = topk_idxs[X_idx]
                    X_idx += 1
                else:
                    topk_idx = topk_idxs_pca[X_pca_idx]
                    X_pca_idx += 1
                # Keep the first three channels and move them to the front
                # axis to match the (3, tile_size, tile_size) slot.
                samples[j,:,:,:] = np.moveaxis(tiles[topk_idx][:,:,:3], -1, 0)
            vis.images(samples, win='knn{}_{}'.format(epoch % 5, i),
                opts={'nrow': nrow, 'title': '=== tile2vec (left half) ===== Epoch {}, #{} ===== PCA (right half) =========='.format(epoch+1, i+1)})

        # Save model
        # A checkpoint is written when this epoch matches the best RF
        # accuracy so far, and always after the final epoch.
        if save_models:
            if rf_acc == np.max(np.array(rf_accs['tile2vec'])) or epoch == epochs-1:
                model_fn = os.path.join(model_dir, 'epoch{}.ckpt'.format(epoch+1))
                torch.save(TileNet.state_dict(), model_fn)