3D Vision Transformer (ViT) model for training on the superband dataset (torch version 2.2.0+cu121)

In [1]:
import torch
import numpy as np
import h5py
import torch.nn as nn
import random, os
import torch.optim as optim
from vit_pytorch.vit_3d import ViT
# Compute device shared by the notebook: the GPU when CUDA is available, else the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def set_manual_seed(
    seed: int,
):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random``, NumPy's legacy global generator and PyTorch
    (CPU and all CUDA devices), exports ``PYTHONHASHSEED`` and puts cuDNN
    into deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Only affects Python processes started after this point; the hash seed of
    # the already-running interpreter cannot be changed.
    os.environ["PYTHONHASHSEED"] = str(seed)
    # Deterministic cuDNN kernels require autotuning to be OFF: with
    # benchmark=True cuDNN may select a different algorithm from run to run,
    # which defeats the seeding above. Setting these flags is harmless on
    # CPU-only machines.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    if torch.cuda.is_available():
        print("the used gpu: ", torch.cuda.device_count(), torch.cuda.is_available())
        # Seed every visible GPU, not just the current one.
        torch.cuda.manual_seed_all(seed)

Read the main superband dataset `0724.hdf5`, which can be downloaded from https://www.scidb.cn/en/file?id=c2487ff51a6230ba2ff55f0b5a8bfe8f

In [2]:
import torch.utils.data as Data
seed = 3
set_manual_seed(seed)
h5pyfn='0724.hdf5'
# Number of shuffled entries that go to the training split; every remaining
# entry is used for validation.
n_train = 2000

traindata=[]
validdata=[]
# The context manager guarantees the HDF5 handle is released once every sample
# has been copied into tensors. The h5py objects held in `reads` are no longer
# readable after this block.
with h5py.File(h5pyfn, 'r') as fast5_data:
    # Sort by entry name first so the seeded shuffle gives the same split on every run.
    reads = np.random.permutation(np.array(sorted(fast5_data.items()), dtype=object))
    for idx, read in enumerate(reads):
        # read[0] is the entry name, read[1] the HDF5 group holding the data.
        sample = [torch.tensor(read[1]['sc_bands'][:], dtype=torch.float),
                  # Regression target is log(Tc + 1).
                  torch.log(torch.tensor(read[1].attrs['Tc']+1, dtype=torch.float))]
        if idx < n_train:
            traindata.append(sample)
        else:
            validdata.append(sample)

print(traindata[0][0].size(), len(traindata))
batch_size = 32
trainLoader = Data.DataLoader(traindata, batch_size=batch_size, shuffle=True, num_workers=8)
validLoader = Data.DataLoader(validdata, batch_size=batch_size, shuffle=False, num_workers=8)

the used gpu:  1 True
torch.Size([18, 32, 32, 32]) 2000


You can adjust the ViT parameters to get better training results.

In [3]:
class Net(nn.Module):
    """Wrapper around the 3D ViT with the hyper-parameters used in this notebook.

    The transformer is stored as ``self.v``; ``forward`` simply delegates to it.
    """

    def __init__(self):
        super().__init__()
        # Hyper-parameters collected in one place so they are easy to tune.
        vit_kwargs = dict(
            image_size=32,
            frames=32,
            image_patch_size=8,
            frame_patch_size=8,
            channels=18,
            num_classes=1,  # single output value per sample
            dim=534,
            depth=3,
            heads=64,
            mlp_dim=1038,
            dropout=0.107,
            emb_dropout=0.197,
        )
        self.v = ViT(**vit_kwargs)

    def forward(self, x):
        """Run ``x`` through the wrapped ViT and return its output."""
        prediction = self.v(x)
        return prediction
model = Net()
# To resume from a saved checkpoint, uncomment the next line. Note that the
# initialisation loop below runs afterwards and would overwrite the loaded
# weights, so it has to be skipped in that case.
#model.load_state_dict(torch.load('model/sc0724.pth'))
model.to(device)

# Re-initialise every parameter whose name contains 'weight' or 'bias':
# tensors with two or more dimensions get Kaiming-normal values, 1-D tensors
# get standard-normal values.
# NOTE(review): the 1-D branch also covers normalisation-layer weights and all
# biases; confirm that drawing those from N(0, 1) is intended.
for param_name, param in model.named_parameters():
    if 'weight' not in param_name and 'bias' not in param_name:
        continue
    if param.dim() >= 2:
        nn.init.kaiming_normal_(param)
    else:
        nn.init.normal_(param)

In [4]:
def train_model(model, train_loader, criterion, optimizer, num_epochs=10, device=None):
    """Train ``model`` for ``num_epochs`` passes over ``train_loader``.

    Args:
        model: Network to optimise; it is switched to training mode.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer that updates the parameters of ``model``.
        num_epochs: Number of passes over ``train_loader``.
        device: Device the batches are moved to. When ``None`` (default) the
            device of the model's first parameter is used, so the function
            does not depend on a notebook-level global.

    Returns:
        The same ``model`` instance after training.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            # Drop only the trailing output dimension. A bare squeeze() would
            # also remove the batch dimension of a single-sample batch and
            # produce a 0-d tensor that no longer matches `labels`.
            outputs = model(inputs).squeeze(-1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        # Count batches while iterating so loaders without __len__ work and an
        # empty loader reports nan instead of raising ZeroDivisionError.
        mean_loss = running_loss / num_batches if num_batches else float('nan')
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {mean_loss:.4f}')
    return model

def evaluate_model(model, val_loader, criterion, device=None):
    """Compute the mean per-batch loss of ``model`` over ``val_loader``.

    The model is switched to evaluation mode and gradients are disabled.

    Args:
        model: Network to evaluate.
        val_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        device: Device the batches are moved to. When ``None`` (default) the
            device of the model's first parameter is used.

    Returns:
        The average of the per-batch losses as a float, or ``nan`` when
        ``val_loader`` yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    model.eval()
    val_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            # Drop only the trailing output dimension so a single-sample batch
            # keeps its batch axis and matches the shape of `labels`.
            outputs = model(inputs).squeeze(-1)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            num_batches += 1
    # An empty loader yields nan rather than a ZeroDivisionError.
    mean_loss = val_loss / num_batches if num_batches else float('nan')
    print(f'Validation Loss: {mean_loss:.4f}')
    return mean_loss

You may need to run the following cell several times to finish the training.

In [5]:
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
PATH = 'model/sc0724.pth'
# Create the checkpoint directory before training starts, so the save at the
# end cannot fail on a missing folder after all epochs have already run.
os.makedirs(os.path.dirname(PATH), exist_ok=True)
# 10 rounds of 10 epochs each, with a validation pass after every round.
for round_idx in range(10):
    model=train_model(model, trainLoader, criterion, optimizer, num_epochs=10)
    val_loss = evaluate_model(model, validLoader, criterion)
torch.save(model.state_dict(), PATH)

Epoch [1/10], Loss: 1.2587
Epoch [2/10], Loss: 1.2202
Epoch [3/10], Loss: 1.0933
Epoch [4/10], Loss: 1.0640
Epoch [5/10], Loss: 1.0537
Epoch [6/10], Loss: 1.0495
Epoch [7/10], Loss: 1.0477
Epoch [8/10], Loss: 1.0437
Epoch [9/10], Loss: 1.0434
Epoch [10/10], Loss: 1.0394
Validation Loss: 1.0627
Validation Loss: 1.0627
Epoch [1/10], Loss: 1.0378
Epoch [2/10], Loss: 1.0362
Epoch [3/10], Loss: 1.0359
Epoch [4/10], Loss: 1.0309
Epoch [5/10], Loss: 1.0297
Epoch [6/10], Loss: 1.0053
Epoch [7/10], Loss: 0.9785
Epoch [8/10], Loss: 0.9679
Epoch [9/10], Loss: 0.9550
Epoch [10/10], Loss: 0.9373
Validation Loss: 0.9938
Validation Loss: 0.9938
Epoch [1/10], Loss: 0.9186
Epoch [2/10], Loss: 0.9085
Epoch [3/10], Loss: 0.8851
Epoch [4/10], Loss: 0.8813
Epoch [5/10], Loss: 0.8649
Epoch [6/10], Loss: 0.8473
Epoch [7/10], Loss: 0.8309
Epoch [8/10], Loss: 0.8185
Epoch [9/10], Loss: 0.7977
Epoch [10/10], Loss: 0.7846
Validation Loss: 0.7553
Validation Loss: 0.7553
Epoch [1/10], Loss: 0.7737
Epoch [2/10], Lo