In [None]:
# default_exp utils

In [None]:
# EXPORT
import torch
import torch.nn as nn
from torch_lr_finder import LRFinder
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib.gridspec as gridspec
from torchvision.utils import make_grid, save_image
import matplotlib.animation as manimati
from matplotlib import animation, rc
from IPython.display import HTML
import pickle
from numpy.linalg import svd

## Util functions for Training Models

In [None]:
# EXPORT
def silu(input):
    """Element-wise Sigmoid Linear Unit.

        silu(x) = x * sigmoid(x)

    Args:
        input: tensor of any shape.

    Returns:
        Tensor with the same shape as ``input``.
    """
    # torch.sigmoid keeps the computation on PyTorch's builtin implementation.
    gate = torch.sigmoid(input)
    return input * gate

class SiLU(nn.Module):
    '''
    Module wrapper around :func:`silu`, so the activation can be used as a layer.

    Applies the Sigmoid Linear Unit (SiLU) function element-wise:

        SiLU(x) = x * sigmoid(x)

    Shape:
        - Input: (N, *) where * means, any number of additional
          dimensions
        - Output: (N, *), same shape as the input

    References:
        -  Related paper:
        https://arxiv.org/pdf/1606.08415.pdf

    Examples:
        >>> m = SiLU()
        >>> input = torch.randn(2)
        >>> output = m(input)

    '''
    def __init__(self):
        # The activation has no parameters or buffers of its own.
        super().__init__()

    def forward(self, x):
        # Delegate to the functional form so both entry points stay in sync.
        return silu(x)

def create_opt(lr,model):
    """Build an Adam optimizer over all parameters of ``model``.

    Args:
        lr: learning rate passed to ``torch.optim.Adam``.
        model: module whose ``parameters()`` will be optimized.

    Returns:
        The constructed ``torch.optim.Adam`` instance.
    """
    return torch.optim.Adam(params=model.parameters(), lr=lr)

def create_one_cycle(opt,max_lr,epochs,dataLoader):
    """Create a one-cycle learning-rate schedule for ``opt``.

    Args:
        opt: optimizer to be scheduled.
        max_lr: peak learning rate of the cycle.
        epochs: number of epochs the schedule spans.
        dataLoader: sized iterable; ``len(dataLoader)`` is used as the number
            of scheduler steps per epoch.

    Returns:
        A ``torch.optim.lr_scheduler.OneCycleLR`` instance.
    """
    batches_per_epoch = len(dataLoader)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        opt,
        max_lr,
        epochs=epochs,
        steps_per_epoch=batches_per_epoch)
    return scheduler

def find_lr(model,opt,loss_func,device,dataLoader,end_lr=100,num_iter=200):
    """Run a learning-rate range test and plot the resulting curve.

    Args:
        model: network to probe.
        opt: optimizer attached to ``model``.
        loss_func: criterion handed to ``LRFinder``.
        device: device handed to ``LRFinder``.
        dataLoader: training loader used for the sweep.
        end_lr: largest learning rate tried during the sweep (default 100,
            the value that used to be hard-coded).
        num_iter: number of iterations in the sweep (default 200, the value
            that used to be hard-coded).

    Returns:
        None. The curve is drawn via ``LRFinder.plot()``.
    """
    lr_finder = LRFinder(model=model, optimizer=opt, criterion=loss_func, device=device)
    lr_finder.range_test(dataLoader, end_lr=end_lr, num_iter=num_iter)
    lr_finder.plot()
    # reset model & opt to their original weights
    lr_finder.reset()
    
def printNumModelParams(model):
    """Print how many parameter tensors / scalar parameters are trainable.

    A "layer" here is one entry of ``model.named_parameters()`` (e.g. a weight
    or a bias tensor); it counts as unfrozen when ``requires_grad`` is True.

    Args:
        model: module exposing ``named_parameters()``.

    Returns:
        None. Two summary lines are written to stdout.
    """
    tensors = [tensor for _, tensor in model.named_parameters()]
    trainable = [tensor for tensor in tensors if tensor.requires_grad]

    layers_req_grad, tot_layers = len(trainable), len(tensors)
    params_req_grad = sum(tensor.nelement() for tensor in trainable)
    tot_params = sum(tensor.nelement() for tensor in tensors)

    print("{0:,} layers require gradients (unfrozen) out of {1:,} layers".format(layers_req_grad, tot_layers))
    print("{0:,} parameters require gradients (unfrozen) out of {1:,} parameters".format(params_req_grad, tot_params))
    
def calcAccuracy(preds, labels):
    """Fraction of samples whose most probable class matches the label.

    Args:
        preds: score tensor; softmax and argmax are taken over dim 1.
        labels: tensor of target class indices compared against the argmax.

    Returns:
        Python float: number of matches divided by ``labels.nelement()``.
    """
    probabilities = torch.softmax(preds, dim=1)
    hits = probabilities.argmax(dim=1).eq(labels)
    return hits.sum().item() / labels.nelement()

def rmse(preds, labels):
    """Root-mean-square error between ``preds`` and ``labels``.

    Works for torch tensors (result is a 0-dim tensor, so gradients can flow)
    and for NumPy arrays (result is a NumPy scalar).

    Args:
        preds: predictions; must support ``-``, ``**`` and ``.mean()``.
        labels: targets, broadcast-compatible with ``preds``.

    Returns:
        ``sqrt(mean((preds - labels) ** 2))``.
    """
    mse = ((preds - labels) ** 2).mean()
    # Dispatch explicitly instead of relying on a bare ``except`` that would
    # also hide unrelated failures (including KeyboardInterrupt).
    if isinstance(mse, torch.Tensor):
        return mse.sqrt()
    return np.sqrt(mse)

def writeMessage(msg, versionName):
    """Echo ``msg`` to stdout and append it as a line to ``<versionName>.txt``.

    Args:
        msg: text to log (a newline is appended in the file).
        versionName: path prefix; ``".txt"`` is added to form the file name.

    Returns:
        None.
    """
    print(msg)
    # ``with`` guarantees the handle is closed even if a write raises.
    with open(versionName+".txt", "a") as myFile:
        myFile.write(msg)
        myFile.write("\n")
    
def plotSample(X):
    """Show the first two channels of ``X`` as stacked images with colorbars.

    Args:
        X: indexable whose items ``X[0]`` and ``X[1]`` are 2-D images.

    Returns:
        None. Draws into a new 20x20 inch figure.
    """
    plt.figure(figsize=(20,20))

    # (subplot position, title) for channel 0 and channel 1, top to bottom.
    panels = ((211, 'Channel 0'), (212, 'Channel 1'))
    for channel, (position, title) in enumerate(panels):
        plt.subplot(position)
        plt.title(title)
        plt.imshow(X[channel])
        plt.colorbar()
    
def plotSampleWpredictionByChannel(sample, prediction):
    """Plot simulated vs. predicted fields on a 2x2 grid, one row per channel.

    Left column shows ``sample``, right column shows ``prediction``; the top
    row is channel 0 and the bottom row is channel 1.

    Args:
        sample: indexable with 2-D images at ``[0]`` and ``[1]``.
        prediction: indexable with 2-D images at ``[0]`` and ``[1]``.

    Returns:
        None. Draws into a new 20x20 inch figure.
    """
    fig, axs = plt.subplots(2, 2)
    fig.set_size_inches(20,20, forward=True)

    for channel in (0, 1):
        axs[channel, 0].imshow(sample[channel])
        axs[channel, 0].set_title('Simulated Channel {}'.format(channel))
        axs[channel, 1].imshow(prediction[channel])
        # Title previously carried a stray ']' for channel 0.
        axs[channel, 1].set_title('Predicted Channel {}'.format(channel))

def plotSampleWprediction(sample,prediction):
    """Show sample and prediction in one mosaic image with a shared colorbar.

    Each input's two channels are stacked vertically; the sample column is
    placed to the left of the prediction column.

    Args:
        sample: indexable with 2-D images at ``[0]`` and ``[1]``.
        prediction: indexable with 2-D images at ``[0]`` and ``[1]``.

    Returns:
        None. Draws into a new 20x20 inch figure with axes hidden.
    """
    def _stack_channels(field):
        # channel 0 above channel 1
        return np.vstack([field[0], field[1]])

    plt.figure(figsize=(20,20))
    mosaic = np.hstack([_stack_channels(sample), _stack_channels(prediction)])
    plt.axis('off')
    plt.imshow(mosaic)
    plt.colorbar()

    
def curl(X,device='cpu'):
    """Forward-difference ``d(f1)/dy - d(f2)/dx`` of a two-channel field.

    ``f1 = X[:,0]`` and ``f2 = X[:,1]``. The "y" difference runs along
    dim 1 of each channel (dim 2 of ``X``), the "x" difference along the last
    dim. The last row / column, where no forward neighbour exists, is set to 0.

    NOTE(review): this is ``df1/dy - df2/dx``; if channel 0 is the x-component
    the conventional 2-D curl has the opposite sign. Confirm intent.

    Args:
        X: tensor indexed as ``X[:, channel, :, :]`` with at least 2 channels.
        device: retained for backward compatibility and no longer needed;
            the zero padding is allocated on ``X``'s own device with ``X``'s
            dtype, so GPU and float64 inputs work without passing it.

    Returns:
        Tensor of shape ``(N, 1, H, W)``.
    """
    f1 = X[:,0,:,:]
    f2 = X[:,1,:,:]
    batch, height, width = f1.shape

    df1_dy = f1[:,1:,:] - f1[:,:-1,:]
    df1_dy = torch.cat([df1_dy, f1.new_zeros((batch, 1, width))], dim=1)

    df2_dx = f2[:,:,1:] - f2[:,:,:-1]
    df2_dx = torch.cat([df2_dx, f2.new_zeros((batch, height, 1))], dim=2)

    # Re-insert a singleton channel axis.
    return (df1_dy - df2_dx)[:,None,:,:]

def jacobian(X,device='cpu'):
    """Forward-difference Jacobian of a two-channel field.

    With ``f1 = X[:,0]`` and ``f2 = X[:,1]``, returns the four partial
    derivatives stacked along dim 1 in the order
    ``[df1/dx, df1/dy, df2/dx, df2/dy]``. "x" is the last dim, "y" the one
    before it. The last column / row of each derivative, where no forward
    neighbour exists, is set to 0.

    Args:
        X: tensor indexed as ``X[:, channel, :, :]`` with at least 2 channels.
        device: retained for backward compatibility and no longer needed;
            the zero padding is allocated on ``X``'s own device with ``X``'s
            dtype, so GPU and float64 inputs work without passing it.

    Returns:
        Tensor of shape ``(N, 4, H, W)``.
    """
    def _diff_x(field):
        # forward difference along the last dim, zero-padded to keep width
        delta = field[:,:,1:] - field[:,:,:-1]
        pad = field.new_zeros((field.shape[0], field.shape[1], 1))
        return torch.cat([delta, pad], dim=2)

    def _diff_y(field):
        # forward difference along dim 1, zero-padded to keep height
        delta = field[:,1:,:] - field[:,:-1,:]
        pad = field.new_zeros((field.shape[0], 1, field.shape[2]))
        return torch.cat([delta, pad], dim=1)

    f1 = X[:,0,:,:]
    f2 = X[:,1,:,:]
    return torch.stack([_diff_x(f1), _diff_y(f1), _diff_x(f2), _diff_y(f2)], dim=1)

# http://farside.ph.utexas.edu/teaching/336L/Fluidhtml/node69.html
# When creating the stream function, the second channel of X is not going to be used. 
# It's there so we don't have to change the AE model code. 
def stream2uv(X,device='cpu'):
    """Turn a stream function (channel 0 of ``X``) into a two-channel field.

    ``u`` is the forward difference of ``X[:,0]`` along dim 1 and ``v`` the
    forward difference along the last dim. Each is extended back to full size
    by repeating its own last row / column.

    Fix: ``v`` used to be padded with the last column of ``u`` (copy-paste
    slip), which put unrelated values into the final column of ``v``.

    NOTE(review): ``v`` is computed as ``+d(psi)/dx``; the commented-out
    Deep Fluid reference elsewhere in this file uses the opposite sign
    (``x[:,:,:-1] - x[:,:,1:]``). Confirm which convention is intended.

    Args:
        X: tensor indexed as ``X[:, 0, :, :]``; other channels are ignored.
        device: unused; kept so existing callers keep working.

    Returns:
        Tensor of shape ``(N, 2, H, W)`` holding ``[u, v]`` along dim 1.
    """
    psi = X[:,0,:,:]

    u = psi[:,1:,:] - psi[:,:-1,:]
    u = torch.cat([u, u[:,-1:,:]], dim=1)   # repeat last row of u

    v = psi[:,:,1:] - psi[:,:,:-1]
    v = torch.cat([v, v[:,:,-1:]], dim=2)   # repeat last column of v (not u)

    return torch.stack([u,v], dim=1)


def show(img,flip=False):
    """Display a channel-first image tensor in a new 40x20 inch figure.

    Args:
        img: tensor convertible with ``.numpy()`` and laid out so that
            transposing axes to ``(1, 2, 0)`` gives an image matplotlib can draw.
        flip: when True, the array is reversed along every axis
            (``np.flip`` with no axis argument) before display.

    Returns:
        None.
    """
    pixels = img.numpy()
    pixels = np.flip(pixels) if flip else pixels
    plt.figure(figsize=(40,20))
    # channel-first -> channel-last for imshow
    plt.imshow(pixels.transpose(1, 2, 0), interpolation='nearest')
    
def convertSimToImage(X): 
    """Min-max normalise ``X`` to 0..255 and return it as a uint8 image stack.

    Normalisation uses the global min / max of the whole tensor. When ``X``
    has exactly two channels (``X.shape[1] == 2``), a third channel filled
    with the constant 128 is appended so each frame has three channels.

    Changes from the earlier version:
      * a constant input (max == min) no longer divides by zero; it maps to 0;
      * the filler channel is created on ``X``'s device, so GPU inputs work;
      * the per-frame Python loop is replaced by a single concatenation.

    Args:
        X: tensor laid out as ``[frames, channels, h, w]``.

    Returns:
        ``torch.uint8`` tensor; ``[frames, 3, h, w]`` for two-channel input,
        otherwise the same shape as ``X``.
    """
    mid = 128   # constant value of the synthetic third channel
    M = 255     # top of the uint8 range

    mn = X.min()
    span = X.max() - mn
    if span == 0:
        # Constant field: (X - mn) / 0 would be NaN, whose uint8 cast is undefined.
        scaled = torch.zeros_like(X)
    else:
        scaled = (X - mn)/span

    # Float -> uint8 truncates toward zero.
    C = (M*scaled).type(torch.uint8)

    if C.shape[1] != 2:
        return C

    filler_shape = (C.shape[0], 1) + tuple(C.shape[2:])
    filler = torch.full(filler_shape, mid, dtype=torch.uint8, device=C.device)
    return torch.cat([C, filler], dim=1)


def create_movie(Xrgb,outfile='sim.mp4',title='surrogate            simulation'):
    """Render a stack of channel-first frames to a video file.

    Every frame is reversed along all axes (``np.flip`` with no axis) and
    transposed to channel-last before being drawn.

    Changes from the earlier version: one ``AxesImage`` is created up front
    and updated with ``set_data`` per frame, instead of stacking a new image
    on the axes for every frame; the callbacks return that artist, which is
    what ``FuncAnimation`` expects with ``blit=True``.

    Args:
        Xrgb: sequence of frames, each accepted by ``np.flip`` and
            ``np.transpose(..., (1, 2, 0))`` -- presumably a NumPy array of
            shape [frames, channels, h, w]; verify against callers.
        outfile: path of the video to write.
        title: figure title.

    Returns:
        None. The movie is written to ``outfile`` (encoded with libx264).
    """
    u_mx = 255 #np.max(np.abs(Xrgb))
    cmap = plt.cm.ocean

    def frame_pixels(i):
        # flipped, channel-last view of frame i
        return np.transpose(np.flip(Xrgb[i]), (1,2,0))

    fig = plt.figure()
    ax = fig.add_subplot(111)
    plt.title(title)
    img = ax.imshow(frame_pixels(0), cmap=cmap, vmin=0, vmax=u_mx)

    # initialization function: plot the background of each frame
    def init():
        img.set_data(frame_pixels(0))
        return (img,)

    # animation function. This is called sequentially
    def animate(i):
        img.set_data(frame_pixels(i))
        return (img,)

    # call the animator. blit=True means only re-draw the parts that have changed.
    anim = animation.FuncAnimation(fig, animate, init_func=init,
                                   frames=len(Xrgb), interval=20, blit=True)
    anim.save(outfile, fps=30, extra_args=['-vcodec', 'libx264'])
    


def create_1_channel_movie(im,outfile='sim.mp4',title='surrogate            simulation'):
    """Render a stack of single-channel frames to a video file.

    Each frame is squeezed and drawn with the RdYlBu colormap on a fixed
    0..255 colour scale.

    Changes from the earlier version: one ``AxesImage`` is created up front
    and updated with ``set_data`` per frame, instead of stacking a new image
    on the axes for every frame; the callbacks return that artist, which is
    what ``FuncAnimation`` expects with ``blit=True``.

    Args:
        im: sequence of frames; ``im[i].squeeze()`` must be a 2-D image.
        outfile: path of the video to write.
        title: figure title.

    Returns:
        None. The movie is written to ``outfile`` (encoded with libx264).
    """
    u_mx = 255 #np.max(np.abs(Xrgb))
    cmap = plt.cm.RdYlBu

    fig = plt.figure()
    ax = fig.add_subplot(111)
    plt.title(title)
    img = ax.imshow(im[0].squeeze(), cmap=cmap, vmin=0, vmax=u_mx)

    # initialization function: plot the background of each frame
    def init():
        img.set_data(im[0].squeeze())
        return (img,)

    # animation function. This is called sequentially
    def animate(i):
        img.set_data(im[i].squeeze())
        return (img,)

    # call the animator. blit=True means only re-draw the parts that have changed.
    anim = animation.FuncAnimation(fig, animate, init_func=init,
                                   frames=len(im), interval=20, blit=True)
    anim.save(outfile, fps=30, extra_args=['-vcodec', 'libx264'])
    
def make_PNNL_movie(surr, real, title='surrogate                    simulation', outfile='sim.mp4'):
    """Write a side-by-side movie of surrogate vs. real frames.

    Both inputs are squeezed, each frame is rotated with ``np.rot90``, the two
    stacks are joined along axis 2 (surrogate first), converted with
    ``convertSimToImage`` and rendered by ``create_1_channel_movie``.

    Args:
        surr: surrogate frames; must support ``.squeeze()`` and iteration.
        real: reference frames with a matching layout.
        title: figure title.
        outfile: path of the video to write.

    Returns:
        None.
    """
    def _rotated(frames):
        # drop singleton axes, then rotate every frame by 90 degrees
        return np.array([np.rot90(frame) for frame in frames.squeeze()])

    side_by_side = np.concatenate([_rotated(surr), _rotated(real)], axis=2)
    im = convertSimToImage(torch.tensor(side_by_side))
    create_1_channel_movie(im, outfile=outfile, title=title)

## Easy pickle

In [None]:
# EXPORT
def pkl_save(D,fn):
    """Pickle object ``D`` into the file at path ``fn`` (overwriting it)."""
    with open(fn, 'wb') as handle:
        pickle.dump(D, handle)

def pkl_load(fn):
    """Load and return the pickled object stored at path ``fn``.

    Warning: unpickling can execute arbitrary code; only use this on files
    from a trusted source.
    """
    with open(fn, 'rb') as handle:
        return pickle.load(handle)

## SVD helper functions

In [None]:
# EXPORT
def computeSpatialandTimePOD(data,simLen,doPlot=False):
    """Compute spatial and temporal POD bases with two SVDs.

    ``data`` is flattened to ``numFrames x vecLength``. A thin SVD of its
    transpose gives the spatial modes. The right singular vectors are then
    reshaped into rows of length ``simLen`` and a second thin SVD of that
    matrix's transpose gives the temporal modes.

    Args:
        data: array whose first axis indexes frames; ``len(data)`` must be a
            multiple of ``simLen`` for the reshape to succeed. Presumably the
            frames of each simulation are stored contiguously -- verify
            against the caller.
        simLen: number of frames per simulation.
        doPlot: when True, plot the cumulative normalised singular values
            after each SVD.

    Returns:
        ``(spatialVecs, S, timeVecs)`` where the columns of ``spatialVecs``
        are the spatial PODs, ``S`` holds the singular values of the first
        SVD, and the rows of ``timeVecs`` are the temporal PODs.
    """
    def _cumulative_energy(singular_values):
        # running share of the singular-value sum, optionally plotted
        share = np.cumsum(singular_values/np.sum(singular_values))
        if doPlot:
            plt.plot(share)
            plt.show()

    frames = data.reshape(len(data), -1)
    spatialVecs, S, vh = svd(frames.T, full_matrices=False)
    _cumulative_energy(S)

    numSims = vh.shape[1]//simLen
    per_sim = vh.reshape(numSims*len(vh), simLen)
    timeVecs, time_sv, _ = svd(per_sim.T, full_matrices=False)
    _cumulative_energy(time_sv)

    return spatialVecs, S, timeVecs.T

def reconFrame(u,frame,numComp=512):
    """Project ``frame`` onto the leading columns of ``u`` and rebuild it.

    Args:
        u: matrix of basis vectors as columns (e.g. from ``u,s,vh = svd(data)``);
            its row count must equal ``frame.size``.
        frame: array (channels x height x width) to reconstruct.
        numComp: maximum number of leading columns of ``u`` to use.

    Returns:
        ``(R, coeffs)``: the reconstruction with ``frame``'s shape, and the
        1-D array of projection coefficients.
    """
    flat = frame.reshape(1, frame.size)
    basis = u[:, :numComp]
    coeffs = (flat @ basis).flatten()

    # Accumulate the weighted modes one at a time.
    R = np.zeros(flat.shape)
    for weight, mode in zip(coeffs, basis.T):
        R += weight*mode

    return R.reshape(frame.shape), coeffs

## Deep Fluid's code

In [None]:
# This curl makes no sense to me. I think the derivative should be taken across channels
# def curl(x, data_format='NHWC'):
#     if data_format == 'NCHW': x = nchw_to_nhwc(x)

#     u = x[:,1:,:,0] - x[:,:-1,:,0] # ds/dy
#     v = x[:,:,:-1,0] - x[:,:,1:,0] # -ds/dx,
#     u = tf.concat([u, tf.expand_dims(u[:,-1,:], axis=1)], axis=1)
#     v = tf.concat([v, tf.expand_dims(v[:,:,-1], axis=2)], axis=2)
#     c = tf.stack([u,v], axis=-1)

#     if data_format == 'NCHW': c = nhwc_to_nchw(c)
#     return c

# def jacobian(x, data_format='NHCW'):
#     if data_format == 'NCHW':
#         x = nchw_to_nhwc(x)

#     dudx = x[:,:,1:,0] - x[:,:,:-1,0]
#     dudy = x[:,1:,:,0] - x[:,:-1,:,0]
#     dvdx = x[:,:,1:,1] - x[:,:,:-1,1]
#     dvdy = x[:,1:,:,1] - x[:,:-1,:,1]
    
#     dudx = tf.concat([dudx,tf.expand_dims(dudx[:,:,-1], axis=2)], axis=2)
#     dvdx = tf.concat([dvdx,tf.expand_dims(dvdx[:,:,-1], axis=2)], axis=2)
#     dudy = tf.concat([dudy,tf.expand_dims(dudy[:,-1,:], axis=1)], axis=1)
#     dvdy = tf.concat([dvdy,tf.expand_dims(dvdy[:,-1,:], axis=1)], axis=1)

#     j = tf.stack([dudx,dudy,dvdx,dvdy], axis=-1)
#     w = tf.expand_dims(dvdx - dudy, axis=-1) # vorticity (for visualization)

#     if data_format == 'NCHW':
#         j = nhwc_to_nchw(j)
#         w = nhwc_to_nchw(w)
#     return j, w

## Testing and How to use

In [None]:
import torchvision
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import sklearn
from sklearn.datasets import make_classification
import numpy as np
import os

In [None]:
# Synthetic two-class data, reshaped into 100 single-channel 28x28 "images".
# A fixed random_state makes the generated samples identical on every run.
X, y = make_classification(
    n_samples=100,
    n_features=28*28,
    n_informative=400,
    n_redundant=2,
    n_repeated=0,
    n_classes=2,
    random_state=0,
)
# One cast to float32 is enough: torch.tensor keeps the NumPy dtype.
X = torch.tensor(X.astype('float32')).reshape(100,1,28,28)
y = torch.tensor(y)
y.dtype

torch.int64

In [None]:
class MyDataset(Dataset):
    """Minimal map-style dataset pairing samples with integer targets.

    Args:
        data: indexable collection of samples; its length is the dataset length.
        targets: labels, converted with ``torch.LongTensor``.
        transform: optional callable applied to each sample on access.
    """

    def __init__(self, data, targets, transform=None):
        self.data, self.transform = data, transform
        self.targets = torch.LongTensor(targets)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample, label = self.data[index], self.targets[index]
        # Only the sample is transformed; the label is returned as stored.
        return (self.transform(sample) if self.transform else sample), label


In [None]:
# Wrap the synthetic tensors in MyDataset and inspect one (sample, label) pair.
dataset = MyDataset(X,y)
dataset[10]

(tensor([[[ 4.5080e-01,  5.8122e-01, -6.5560e-01, -3.9072e-01,  2.0177e-01,
            7.6985e-01,  1.1727e+00,  1.8524e+01,  2.5810e-01, -1.6648e+01,
           -1.6201e+01,  1.9995e-01,  2.5225e+01,  1.7693e-01, -8.9749e-01,
            4.2682e-01, -1.8444e+00,  2.1024e+00,  3.5770e-01,  2.3846e-01,
           -3.0043e+00, -6.1409e-01, -5.6685e+00,  4.5927e+00, -1.3857e+01,
            6.3421e+00, -1.7266e+01,  5.3794e+00],
          [ 1.1012e+01,  1.2134e+00,  9.6571e-01,  1.4821e+01,  1.2435e+00,
            1.0125e-02, -7.6497e-01, -1.6214e+00,  1.2660e+01, -8.7501e+00,
            1.8245e+00,  8.6400e+00,  1.0887e+01, -1.6755e+00, -2.7765e-01,
           -2.1815e+01, -2.8381e+00,  1.5722e+01,  1.0457e+01,  2.1155e+00,
            6.7007e+00,  8.0044e+00, -1.7037e+00, -1.3783e+01,  6.7658e-01,
            5.4921e-01, -1.5484e+00, -1.4696e+01],
          [-7.7276e+00,  1.0670e+00, -2.6884e+01,  6.9988e-02,  7.3225e-01,
           -4.9642e+00, -2.3999e+01,  7.7308e+00,  3.1441e+00,

In [None]:
# Batches of 10 samples; len(dataLoader) later sets the scheduler's steps per epoch.
dataLoader = DataLoader(dataset,batch_size=10)

In [None]:
class ConvNet(nn.Module):
    """Small CNN used to exercise the training utilities in this notebook.

    Two conv -> ReLU -> 2x2 max-pool stages (1 -> 32 -> 64 channels), dropout,
    then two linear layers producing 10 outputs. The first linear layer
    expects 7 * 7 * 64 features, i.e. a 28x28 single-channel input halved twice.
    """

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        # 5x5 "same" convolution, ReLU, then spatial down-sampling by 2
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))

    def __init__(self):
        super().__init__()
        self.layer1 = self._conv_stage(1, 32)
        self.layer2 = self._conv_stage(32, 64)
        self.drop_out = nn.Dropout()
        self.fc1 = nn.Linear(7 * 7 * 64, 1000)
        self.fc2 = nn.Linear(1000, 10)

    def forward(self, x):
        features = self.layer2(self.layer1(x))
        # flatten everything except the batch dimension
        flat = features.reshape(features.size(0), -1)
        return self.fc2(self.fc1(self.drop_out(flat)))

In [None]:
# Instantiate the demo network used by the cells below.
model = ConvNet()

In [None]:
# Training hyper-parameters for the demo; max_lr is used both as Adam's lr
# here and as the peak lr of the one-cycle schedule created next.
max_lr = 1e-3
epochs = 100
opt = create_opt(max_lr,model)

In [None]:
# One-cycle schedule spanning epochs * len(dataLoader) optimizer steps.
opt_sched = create_one_cycle(opt,max_lr,epochs,dataLoader)

In [None]:
# Pull the first batch and check the shape of its inputs.
batch = next(iter(dataLoader))
batch[0].shape

torch.Size([10, 1, 28, 28])

In [None]:
# Forward pass on the batch inputs; the result has one row of scores per sample.
out = model(batch[0])
out.shape

torch.Size([10, 10])

In [None]:
# Classification criterion used for the demo.
loss_func = torch.nn.CrossEntropyLoss()
loss_func

CrossEntropyLoss()

In [None]:
# Sanity check: the criterion evaluates on the model outputs and batch labels.
loss_func(out,batch[1])

tensor(3.2888, grad_fn=<NllLossBackward>)

In [None]:
# Report trainable vs. total parameter tensors and scalar parameters.
printNumModelParams(model)

8 layers require gradients (unfrozen) out of 8 layers
3,199,106 parameters require gradients (unfrozen) out of 3,199,106 parameters


In [None]:
# Small random (batch, channel, height, width) field for exercising curl / jacobian.
bz = 8
h = 4
w = 3
c = 2
x = torch.rand(bz,c,h,w )
x.shape

torch.Size([8, 2, 4, 3])

In [None]:
# curl with the default device argument; the output keeps (h, w) and has one channel.
a= curl(x)
a.shape

torch.Size([8, 1, 4, 3])

In [None]:
# jacobian with the default device argument; four derivative maps per sample.
J = jacobian(x)
print(J.shape)


torch.Size([8, 4, 4, 3])


In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

Using device: cuda


In [None]:
# Move the random field to the selected device for the checks below.
x = x.to(device)

In [None]:
# Same curl check, now with x on `device` and the device passed explicitly.
a = curl(x,device)
a.shape

torch.Size([8, 1, 4, 3])

In [None]:
# Same jacobian check, now with x on `device` and the device passed explicitly.
J = jacobian(x,device)
print(J.shape)

torch.Size([8, 4, 4, 3])


In [None]:
# stream2uv on the dataset tensor X (capital X); only channel 0 of X is read.
stream2uv(X).shape

torch.Size([100, 2, 28, 28])

In [None]:
# The device argument is not used inside stream2uv, so this runs on X as-is.
stream2uv(X,device).shape

torch.Size([100, 2, 28, 28])

In [None]:
# The result must not depend on the device argument; assert it instead of
# printing the full difference tensor.
assert torch.equal(stream2uv(X), stream2uv(X,device))

tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
        