In [1]:
import torch
import torch.nn as nn
import torchvision
from netCDF4 import Dataset
import numpy as np
# from torchsummary import summary

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# List the visible GPUs with their utilization and current memory usage.
!nvidia-smi

Fri Apr 22 01:25:11 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  On   | 00000004:05:00.0 Off |                    0 |
| N/A   39C    P0    37W / 300W |   2048MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  On   | 00000035:03:00.0 Off |                    0 |
| N/A   33C    P0    35W / 300W |   3072MiB / 16384MiB |      0%      Default |
|       

In [12]:
# Prefer the first CUDA GPU; fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [3]:
# Load the 'o3' field from the training NetCDF file and shape it into a
# 5-D tensor: (num samples, channels, z, y, x).
data_dir = '/home/jcurtis2/hackathon_data/'  # NOTE(review): machine-specific absolute path; make configurable
wrf_filename = '%straining.nc' % data_dir

# The context manager closes the file handle once the values are in memory
# (previously the Dataset stayed open for the rest of the session).
with Dataset(wrf_filename, "r", format="NETCDF4") as ncf:
    # np.array copies the variable's values, so they outlive the closed file.
    data = torch.from_numpy(np.array(ncf.variables['o3']))
    # NOTE(review): Conv3d weights default to float32; if 'o3' is stored as
    # float64, add .float() here before feeding the network.

# Resize only the last two axes (y, x) to 157x157; leading axes are untouched.
data = torchvision.transforms.Resize((157, 157))(data)
data = torch.unsqueeze(data, 1)     # insert a channel axis at dim 1
# NOTE(review): replicates the single field into 10 identical channels, a full
# 10x copy in memory; `data.expand(-1, 10, -1, -1, -1)` would be a zero-copy view.
data = data.repeat(1, 10, 1, 1, 1)
print(data.shape) # num samples, channels, z, y, x

torch.Size([133, 10, 39, 157, 157])


In [4]:
# Hyperparameter settings:
#   NUM_EPOCHS - number of passes made by the training loop
#   BATCH_SIZE - mini-batch size
NUM_EPOCHS, BATCH_SIZE = 1, 4

In [17]:
class DoubleConv(nn.Module):
    """Two stacked (Conv3d -> BatchNorm3d -> ReLU) stages.

    The convolutions are unpadded, so each one trims ``kernel - 1`` voxels
    from every spatial axis (two 3x3x3 convs shrink each axis by 4).

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by the second convolution.
        mid_channels: channels between the two convolutions; defaults to
            ``out_channels`` when omitted.
        first_kernel: kernel size of the first convolution.
        second_kernel: kernel size of the second convolution.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, first_kernel=3, second_kernel=3):
        super().__init__()
        hidden = out_channels if mid_channels is None else mid_channels

        layers = [
            nn.Conv3d(in_channels, hidden, first_kernel),
            nn.BatchNorm3d(hidden),
            nn.ReLU(),
            nn.Conv3d(hidden, out_channels, second_kernel),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(),
        ]
        # The attribute name fixes the state_dict keys ("first_down.<i>.*"),
        # so it is kept even though it holds both conv stages.
        self.first_down = nn.Sequential(*layers)

    def forward(self, x):
        """Run both conv stages on a 5-D (N, C, D, H, W) tensor."""
        return self.first_down(x)
    
class Down(nn.Module):
    """Encoder downsampling stage: MaxPool3d(2) followed by a DoubleConv.

    The pool halves each spatial axis (rounding down); the DoubleConv with
    its default 3x3x3 kernels then trims a further 4 voxels per axis.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by the DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool3d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Pool, then convolve, the 5-D input tensor."""
        return self.maxpool_conv(x)

In [18]:
class UNet(nn.Module):
    """Encoder half of a 3-D U-Net.

    Only the contracting path (inc + four Down stages) exists; the decoder
    is still a TODO, so ``forward`` returns the bottleneck features.

    Args:
        n_channels: channels of the input volume.
        n_classes: stored for the future decoder; unused so far.
        bilinear: when True, the bottleneck (down4) width is halved from
            1024 to 512 channels; it has no other effect yet.

    NOTE(review): every DoubleConv here uses unpadded 3x3x3 convolutions.
    For an input depth of 39 the depth goes 39 -> 35 (inc) -> 13 (down1)
    -> 2 (down2), and down3's pool reduces it to 1, which is smaller than the
    kernel, so forward() raises. Consider padding=1 convolutions (or fewer
    stages) before training on such volumes.
    """

    def __init__(self, n_channels, n_classes, bilinear=False):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        factor = 2 if bilinear else 1

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)
        # TODO: decoder, e.g. self.up1 = Up(1024, 512 // factor, bilinear)

    def forward(self, x):
        """Pass ``x`` through the encoder stages in order; return the last output."""
        for stage in (self.inc, self.down1, self.down2, self.down3, self.down4):
            x = stage(x)
        return x

    def test_correct_size(self, x):
        """Return the shape of ``forward(x)`` (quick sanity helper)."""
        return self.forward(x).shape

In [23]:
%%time
# Shape smoke test: push a random batch with the same (channels, z, y, x) as
# `data` through a first DoubleConv (7^3 first kernel) and one Down stage.
# Tensors and modules go to `device` rather than a hard-coded .cuda(), so the
# cell also runs on CPU-only machines.
x = torch.randn(4, 10, 39, 157, 157, device=device)
inc = DoubleConv(10, 16, first_kernel=7).to(device)
down1 = Down(16, 32).to(device)
with torch.no_grad():  # shapes only: don't keep activations for backward
    x1 = inc(x)
    print(x1.shape)  # expected torch.Size([4, 16, 31, 149, 149])
    x2 = down1(x1)
    print(x2.shape)  # expected torch.Size([4, 32, 11, 70, 70])

torch.Size([4, 16, 31, 149, 149])
torch.Size([4, 32, 11, 70, 70])
CPU times: user 918 ms, sys: 30 ms, total: 948 ms
Wall time: 952 ms


In [None]:
## Train
# TODO: `trainloader`, `net`, `criterion` and `optimizer` are not defined yet,
# so the per-batch template below stays commented out and this cell currently
# performs no training; it only runs the empty epoch loop.
for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
#     for i, (inputs, labels) in enumerate(trainloader):  # avoid rebinding the global `data`
#         inputs, labels = inputs.to(device), labels.to(device)
#
#         optimizer.zero_grad()
#         outputs = net(inputs)
#         loss = criterion(outputs, labels)
#         loss.backward()
#         optimizer.step()
#
#         running_loss += loss.item()  # item() must be called; `loss.item` is the bound method
#         if i % 10 == 9:  # report once per full block of 10 batches
#             print(f"[epoch {epoch + 1}, batch {i + 1}] loss: {running_loss / 10:.4f}")
#             running_loss = 0.0

print("Finished Training")