In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torch.utils.data import sampler
import torchvision.datasets as dset
import torchvision.transforms as T
import os
import numpy as np
import matplotlib.pyplot as plt

In [2]:
from unet_parts import *
from unet_model3 import UNet


In [3]:
device =torch.device('cuda')

In [4]:
device

device(type='cuda')

In [5]:
stldata = np.load('C:/Users/onebean5/Desktop/fluid/data/stlsave.npy')
xdata = np.load('C:/Users/onebean5/Desktop/fluid/data/xsave.npy').astype(int)
vdata = np.load('C:/Users/onebean5/Desktop/fluid/data/vsave.npy')
ydata = np.load('C:/Users/onebean5/Desktop/fluid/data/ysave.npy')
stldata = torch.tensor(stldata)
xdata = torch.tensor(xdata)
vdata = torch.tensor(vdata)
ydata = torch.tensor(ydata)

In [6]:
print(stldata.shape)
print(xdata.shape)
print(vdata.shape)
print(ydata.shape)

torch.Size([12, 1, 300, 300])
torch.Size([75, 1, 1, 1])
torch.Size([75, 1, 1, 1])
torch.Size([75, 1, 300, 300])


In [7]:
stldata = (stldata-150)/150 #normalize to range [-1, 1]

In [8]:
print(torch.min(stldata), torch.max(stldata))

tensor(-1.) tensor(1.)


In [9]:
def masked_mse(x,pred,y):
  mse=torch.square(y-pred)
  mse[x!=torch.max(x)]=0
  return mse

def train(model,
          device,
          dataset_n,
          epochs: int = 5,
          batch_size: int = 1,
          learning_rate: float = 1e-5,
          val_percent: float = 0.1,
          weight_decay: float = 1e-8,
          momentum: float = 0.999,
          amp: bool = False,
          gradient_clipping: float = 1.0,):
  global loss_save
  
  n_val = int(dataset_n * val_percent)
  n_train = dataset_n - n_val
  train_i, val_i = random_split(torch.arange(dataset_n), [n_train, n_val])
  loader_args = dict(batch_size=10, num_workers=os.cpu_count(), pin_memory=True)
  train_i_loader = DataLoader(train_i, shuffle=True, **loader_args)
  val_i_loader = DataLoader(val_i, shuffle=False, drop_last=True, **loader_args)


  optimizer = optim.RMSprop(model.parameters(),
                            lr=learning_rate,
                            weight_decay=weight_decay,
                            momentum=momentum,
                            foreach=True)

  grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
  model = model.to(device)
  for epoch in range(1, epochs + 1):
    model.train()
    epoch_loss = 0

    for x in train_i_loader:
      x_data = torch.squeeze(stldata[xdata[x,:,0,0]],1)
      v_data = vdata[x]
      y_data = ydata[x]

      x_data = x_data.to(device)
      v_data = v_data.to(device)
      y_data = y_data.to(device)
      with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
        pred = model(x_data,v_data)
        loss = torch.mean(masked_mse(x_data, pred, y_data))

      optimizer.zero_grad(set_to_none=True)
      grad_scaler.scale(loss).backward()
      torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
      grad_scaler.step(optimizer)
      grad_scaler.update()


      epoch_loss += loss.item()
    loss_save.append(epoch_loss)
    if epoch %100 == 0:
      print(f'epochs: {epoch}, loss: {epoch_loss}')

In [10]:
unet = UNet(1,1)


In [12]:
loss_save = []

In [13]:
train(unet,device,75, batch_size=32,epochs = 10, learning_rate = 1e-6)

In [14]:
print(loss_save)

[1914.705062866211, 1935.2310028076172, 1938.7019500732422, 1925.0524139404297, 1919.752426147461, 1934.7607727050781, 1884.6413269042969, 1902.2943420410156, 1882.3775329589844, 1858.8748626708984]


Test Data & Visualize

In [219]:
def test_diff(model, x_data, v_data, y_data):
  model = model.to(torch.device('cpu'))
  with torch.no_grad():
    pred = model(x_data,v_data)
    mse = masked_mse(x_data, pred.data.cpu().numpy(),v_data)
  return pred, mse


In [220]:
dataidx = 20

x_test = torch.squeeze(stldata[xdata[dataidx,:,0,0]],1).reshape((1,1,300,300))
v_test = vdata[dataidx].reshape((1,1,1,1))
y_test = ydata[dataidx].reshape((1,1,300,300))

torch.Size([2, 64, 75, 75])

In [65]:
pred, mse = test_diff(unet,x_test,v_test,y_test)

In [66]:
f, axes = plt.subplots(1,3)
f.set_size_inches((20, 5))
plt.subplots_adjust(wspace = 0.3, hspace = 0.3)
sctt_1 = axes[0].pcolor(y_test[0][0],cmap='rainbow')
sctt_2 = axes[1].pcolor(pred[0][0],cmap='rainbow')
sctt_3 = axes[2].pcolor(mse[0][0],cmap='rainbow')