### Multiscale Residual Spatiotemporal Vision Transformer (MR-ST-ViT | PixFormer)

### Misael M. Morales, 2024
***

In [1]:
# Star import from the project's main.py.
# NOTE(review): names used in later cells but not defined in this notebook
# (Heterogeneity, PixFormer, CustomLoss, Dataset, DataLoader, random_split,
# np, torch, os, time, ...) are presumably supplied by this import; explicit
# imports would make their origin visible.
from main import *

# Experiment object whose methods are driven by the next few cells.
# This cell printed the VERSION INFO banner shown in its output (whether it
# comes from the import or the constructor is not visible here).
hete = Heterogeneity()


------------------------------------------------------------
----------------------- VERSION INFO -----------------------
Torch version: 2.1.2+cu118 | Torch Built with CUDA? True
# Device(s) available: 1, Name(s): Quadro M6000 24GB
------------------------------------------------------------



In [None]:
# Presumably builds the train/valid/test loaders (implemented in main.py, not
# visible here). Prompt "In [None]" means this cell was not run in this save.
hete.make_dataloaders()

In [None]:
# Training entry point of `hete` (implemented in main.py, not visible here).
hete.trainer()

In [None]:
# Evaluation entry point of `hete` (implemented in main.py, not visible here).
hete.tester()

In [None]:
# Plot the loss history held by `hete`; presumably recorded by trainer() —
# confirm in main.py.
hete.plot_losses()

***
# END

In [6]:
class MyDataset(Dataset):
    '''
    Generate a custom dataset from .npz files, one file (realization) per item.

    Each item is a pair of float32 tensors with a leading time axis:
        (x) porosity, log10(permeability), timestep index -> (tsteps, 3, S, S)
        (y) pressure, saturation                          -> (tsteps, 2, S, S)
    where S = orig_img, or half_img when a `transform` is given.

    Assumes each file stores 'poro' and 'perm' as single (orig_img, orig_img)
    maps and 'pres'/'sat' as (tsteps, orig_img, orig_img) stacks -- TODO
    confirm against the data generator.

    Normalization statistics are computed per item (per file), not over the
    whole dataset.

    Parameters
    ----------
    file_paths : sequence of str
        Paths to the .npz files.
    transform : callable, optional
        Applied per timestep to the axis-reversed array `X[i].T`; its output
        must fit a (channels, half_img, half_img) slot.
    norm_type : str
        "None", "MinMax", "ExtMinMax" or "Standard" (see `normalize`).
    '''
    def __init__(self, file_paths, transform=None, norm_type:str='MinMax'):
        self.file_paths = file_paths
        self.transform  = transform
        self.tsteps     = 60     # frames per realization
        self.x_channels = 3      # poro, log10(perm), timestep
        self.y_channels = 2      # pres, sat
        self.orig_img   = 256    # stored spatial size
        self.half_img   = 64     # spatial size after `transform`
        self.norm_type  = norm_type
        self.norm       = self.normalize

    def normalize(self, x):
        '''
        Normalize each channel (axis 1) of `x` with statistics pooled over all
        other axes of `x`.

        MinMax    -> [0, 1]
        ExtMinMax -> [-1, 1]
        Standard  -> zero mean, unit (population) std
        None      -> `x` is returned unchanged

        A constant channel (zero range / zero std) is mapped to the lower bound
        (MinMax, ExtMinMax) or to 0 (Standard) instead of NaN from 0/0.

        Raises
        ------
        ValueError
            If `self.norm_type` is not one of the schemes above.
        '''
        if self.norm_type not in ('None', 'MinMax', 'ExtMinMax', 'Standard'):
            raise ValueError('Invalid normalization scheme: {} | Select ["None", "MinMax", "ExtMinMax", "Standard"]'.format(self.norm_type))
        if self.norm_type == 'None':
            return x
        # reduce over every axis except the channel axis (1)
        axes = tuple(a for a in range(x.ndim) if a != 1)
        if self.norm_type == 'Standard':
            center = x.mean(axis=axes, keepdims=True)
            scale  = x.std(axis=axes, keepdims=True)
        else:
            center = x.min(axis=axes, keepdims=True)
            scale  = x.max(axis=axes, keepdims=True) - center
        # constant channels: divide by 1 so they stay at 0 after centering
        scale  = np.where(scale == 0, 1, scale)
        x_norm = (x - center) / scale
        if self.norm_type == 'ExtMinMax':
            x_norm = x_norm * 2 - 1
        return x_norm

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        # np.load on an .npz returns a lazily-read NpzFile holding an open file
        # handle; read every array inside `with` so the handle is released now
        # instead of leaking one per item in every DataLoader worker.
        with np.load(self.file_paths[idx]) as data:
            poro = data['poro']
            perm = np.log10(data['perm'])
            pres = data['pres']
            sat  = data['sat']
        reps  = (self.tsteps, 1, 1, 1)
        poro  = np.tile(poro, reps)          # repeat the static maps over time
        perm  = np.tile(perm, reps)
        # timestep channel: frame k (1-based) is filled with the constant k
        tstep = np.tile(np.arange(1, self.tsteps+1).reshape(reps), (1, 1, self.orig_img, self.orig_img))
        pres  = np.expand_dims(pres, 1)      # add the channel axis
        sat   = np.expand_dims(sat, 1)
        X_data = np.concatenate([poro, perm, tstep], axis=1).reshape(-1, self.x_channels, self.orig_img, self.orig_img)
        y_data = np.concatenate([pres, sat], axis=1).reshape(-1, self.y_channels, self.orig_img, self.orig_img)
        if self.transform:
            X_data_t = np.zeros((len(X_data), self.x_channels, self.half_img, self.half_img))
            y_data_t = np.zeros((len(y_data), self.y_channels, self.half_img, self.half_img))
            for i in range(len(X_data)):
                X_data_t[i] = self.transform(X_data[i].T)
                y_data_t[i] = self.transform(y_data[i].T)
            X_data, y_data = X_data_t, y_data_t
        x, y = torch.Tensor(self.norm(X_data)), torch.Tensor(self.norm(y_data))
        return x, y

In [7]:
class MyDataLoader(DataLoader):
    '''
    DataLoader that restricts every batch to one temporal window and folds
    time into the batch axis.

    (train): x,y at timesteps 0-40   -> frames [0, 40)
    (valid): x,y at timesteps 40-50  -> frames [40, 50)
    (test):  x,y at timesteps 50-60  -> frames [50, end)

    Inside the window every max(1, T // 10)-th frame is kept (T = window
    length), then (b, t, c, h, w) batches are yielded as (b*t, c, h, w).

    `num_workers` defaults to 8 and `pin_memory` to True as before, but both
    can now be overridden (e.g. num_workers=0 when worker processes crash or
    run out of memory); passing them used to raise a duplicate-keyword
    TypeError.
    '''
    # temporal window (slice over axis 1 of each batch) for every mode
    _WINDOWS = {'train': slice(None, 40), 'valid': slice(40, 50), 'test': slice(50, None)}

    def __init__(self, *args, mode:str=None, **kwargs):
        kwargs.setdefault('num_workers', 8)
        kwargs.setdefault('pin_memory', True)
        super().__init__(*args, **kwargs)
        self.mode = mode

    def __iter__(self):
        # validate before any worker is started or any batch is loaded
        if self.mode not in self._WINDOWS:
            raise ValueError('Invalid mode: {} | select between "train", "valid" or "test"'.format(self.mode))
        window = self._WINDOWS[self.mode]
        for X_data, y_data in super().__iter__():   # each batch has shape (b, t, c, h, w)
            X_data, y_data = X_data[:, window], y_data[:, window]
            # keep ~10 frames; max(1, ...) avoids a zero slice step when the
            # window holds fewer than 10 frames
            X_data = X_data[:, ::max(1, X_data.shape[1] // 10)]
            y_data = y_data[:, ::max(1, y_data.shape[1] // 10)]
            # merge batch and time axes: (b, t, c, h, w) -> (b*t, c, h, w)
            yield X_data.flatten(0, 1), y_data.flatten(0, 1)

In [8]:
# Collect the .npz realizations in a deterministic (sorted) order; os.listdir
# order is platform-dependent and the folder may contain other files.
data_dir   = 'Fdataset'
file_names = sorted(f for f in os.listdir(data_dir) if f.lower().endswith('.npz'))
file_paths = [os.path.join(data_dir, file_name) for file_name in file_names]

dataset = MyDataset(file_paths)
# 50/25/25 split by realization. The test split takes the remainder so the
# three sizes always sum to len(dataset); random_split raises otherwise
# (e.g. 10 files -> 5 + 2 + 2 = 9).
train_size = int(0.5*len(dataset))
valid_size = int(0.25*len(dataset))
test_size  = len(dataset) - train_size - valid_size

# Seeded generator: the same realizations land in train/valid/test on every
# run, so a held-out test set stays held out across sessions.
split_generator = torch.Generator().manual_seed(42)
train_dataset, valid_dataset, test_dataset = random_split(dataset, [train_size, valid_size, test_size], generator=split_generator)
train_loader = MyDataLoader(train_dataset, batch_size=32, mode='train')
valid_loader = MyDataLoader(valid_dataset, batch_size=32, mode='valid')
test_loader = MyDataLoader(test_dataset, batch_size=32, mode='test')

In [12]:
# Train on the GPU when one is available. Previously the model and all batches
# stayed on the CPU, which also made the loaders' pin_memory=True pointless.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = PixFormer().to(device)   # move before building the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)
criterion = CustomLoss()
if isinstance(criterion, torch.nn.Module):
    # also move any registered buffers/parameters of the loss
    criterion = criterion.to(device)

train_loss, valid_loss, train_ssim, valid_ssim = [], [], [], []
best_val_loss, best_model = float('inf'), None
time0 = time.time()

for epoch in range(10):
    start_time = time.time()
    model.train()
    epoch_loss, epoch_ssim, n_batches = 0.0, 0.0, 0
    for x, y in train_loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        optimizer.zero_grad()
        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        # Metric only: compute without autograd and store a Python number.
        # Accumulating the raw tensor kept every batch's graph alive for the
        # whole epoch. Assumes criterion.ssim returns a scalar (float or 0-d
        # tensor) -- TODO confirm in CustomLoss.
        with torch.no_grad():
            epoch_ssim += float(1.0 - criterion.ssim(y_pred, y))
        n_batches += 1
    if n_batches == 0:
        # previously surfaced as a NameError on the loop variable `i`
        raise RuntimeError('train_loader yielded no batches')
    train_loss.append(epoch_loss/n_batches)
    train_ssim.append(epoch_ssim/n_batches)
    

RuntimeError: DataLoader worker (pid(s) 15172, 10664, 17652, 12804, 9324, 14040, 23376, 13504) exited unexpectedly