# Deep learning-based automated rock classification via high-resolution drone-captured core sample imagery
***
### Domenico M. Crisafulli, Misael M. Morales, and Carlos Torres-Verdin
#### The University of Texas at Austin, 2024
***

## Build and Train NN-classifier
| Class             | OLD   | New   |
| ---               | ---   | ---   |
| Background        | 0     | 0     |
| Sandstone type 1  | 1     | 2     |
| Shaly Rock        | 2     | 3     |
| Sandstone type 2  | 3     | 4     |
| Carbonate         | 4     | 5     |
| Shale             | 5     | 6     |
| Sandstone type 3  | 6     | 7     |
| Box               | 10    | 1     |

***
https://www.akshaymakes.com/blogs/vision-transformer

MaskFormer for image segmentation

https://huggingface.co/docs/transformers/en/model_doc/maskformer

In [1]:
from main import *
device = check_torch()

------------------------------------------------------------
----------------------- VERSION INFO -----------------------
Torch version: 2.3.1+cu121 | Torch Built with CUDA? True
# Device(s) available: 1, Name(s): NVIDIA GeForce RTX 3080
------------------------------------------------------------


In [2]:
class RockClassification(nn.Module):
    def __init__(self):
        super(RockClassification, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        
        self.convt3 = nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.convt2 = nn.ConvTranspose2d(in_channels=16, out_channels=8, kernel_size=3, stride=1, padding=1)
        self.convt1 = nn.ConvTranspose2d(in_channels=8, out_channels=1, kernel_size=3, stride=1, padding=1)

        self.bn0 = nn.BatchNorm2d(1)
        self.bn1 = nn.BatchNorm2d(8)
        self.bn2 = nn.BatchNorm2d(16)
        self.bn3 = nn.BatchNorm2d(32)

        self.pool = nn.MaxPool2d(kernel_size=2, padding=0)
        self.upsm = nn.Upsample(scale_factor=2, mode='nearest')

        self.relu = nn.ReLU()
        self.soft = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.pool(self.relu(self.bn1(self.conv1(x))))
        x = self.pool(self.relu(self.bn2(self.conv2(x))))
        x = self.pool(self.relu(self.bn3(self.conv3(x))))

        x = self.relu(self.bn2(self.convt3(self.upsm(x))))
        x = self.relu(self.bn1(self.convt2(self.upsm(x))))
        x = self.relu(self.bn0(self.convt1(self.upsm(x))))

        return torch.round(x) #self.soft(x)

In [3]:
class CustomDataset(Dataset):
    def __init__(self, input_dir, output_dir, transform=None):
        self.input_dir = input_dir
        self.input_filenames = os.listdir(input_dir)
        self.input_filenames.sort()
        self.output_dir = output_dir
        self.output_filenames = os.listdir(output_dir)
        self.output_filenames.sort()
        self.remap_dict = {0:0, 1:2, 2:3, 3:4, 4:5, 5:6, 6:7, 10:1}
        self.transform = transform

    def __len__(self):
        return len(self.input_filenames)

    def __getitem__(self, idx):
        input_npy_path  = os.path.join(self.input_dir, self.input_filenames[idx])
        output_npy_path = os.path.join(self.output_dir, self.output_filenames[idx])
        x = np.load(input_npy_path)/255
        input_img = torch.tensor(x).unsqueeze(0).nan_to_num(0).type(torch.float32)
        #output_img = np.vectorize(self.remap_dict.get)(np.load(output_npy_path))
        output_img = np.load(output_npy_path)
        output_img = torch.tensor(output_img).unsqueeze(0).nan_to_num(0).type(torch.int32)
        if self.transform is not None:
            input_img, output_img = self.transform(input_img), self.transform(output_img)
        return input_img, output_img

In [4]:
class PatchTransform:
    def __init__(self, patch_w:int=6, patch_h:int=8):
        self.patch_w = patch_w
        self.patch_h = patch_h

    def __call__(self, img):
        sw = img.shape[-2] // self.patch_w
        sh = img.shape[-1] // self.patch_h
        patches = img.unfold(-2, sw, sw).unfold(-2, sh, sh).reshape(-1, 1, sw, sh)
        return patches

In [5]:
class PatchNonzeroFilter:
    def __init__(self, background_class:int=0, verbose:bool=False):
        self.background = background_class
        self.verbose = verbose
    
    def __call__(self, ximg, yimg):
        xmask = torch.sum(ximg, dim=(-3,-2,-1)) != self.background
        xfilt = ximg[xmask]
        ymask = torch.sum(yimg, dim=(-3,-2,-1)) != self.background
        yfilt = yimg[ymask]
        if xfilt.shape[0] != yfilt.shape[0]:
            print('Warning: Input and Output shapes do not match | Filtering with smaller mask...') if self.verbose else None
            mask = ymask if xfilt.shape[0] > yfilt.shape[0] else xmask
        else:
            mask = xmask
        return ximg[mask], yimg[mask], mask

In [6]:
class PatchReconstruct:
    def __init__(self, patch_w:int=6, patch_h:int=8):
        self.patch_w = patch_w
        self.patch_h = patch_h

    def __call__(self, img, mask):
        s = img.size(-1)
        b = mask.size(0)
        def repatch(_):
            _ = torch.permute(_, (0,2,1,3,4))
            _ = torch.reshape(_, (b, 1, self.patch_w, self.patch_h, s, s))
            _ = torch.permute(_, (0,1,2,4,3,5))
            _ = torch.reshape(_, (b, 1, s*self.patch_w, s*self.patch_h))
            return _
        xout = torch.zeros((b, self.patch_w*self.patch_h, 1, s, s), dtype=img.dtype)
        xout[mask] = img
        xout = repatch(xout)
        return xout

In [7]:
patch_transform = PatchTransform(patch_w=6, patch_h=8)
dataset = CustomDataset(input_dir='data/x_images', output_dir='data/y_images', transform=patch_transform)

train_percent = 0.8
train, test   = random_split(dataset, [int(train_percent*len(dataset)), len(dataset)-int(train_percent*len(dataset))])
train, valid  = random_split(train, [int(train_percent*len(train)), len(train)-int(train_percent*len(train))])

batch_size = 4
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid, batch_size=batch_size, shuffle=True)
test_loader  = DataLoader(test,  batch_size=batch_size, shuffle=True)
all_loader   = DataLoader(dataset, batch_size=8, shuffle=False)

In [8]:
# Training Loop
model     = RockClassification().to(device)
nparams   = sum(p.numel() for p in model.parameters())
print('# parameters: {:,}'.format(nparams))

criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-8)

epochs, monitor = 301, 10
train_loss, valid_loss = [], []
for epoch in range(epochs):
    # training
    epoch_train_loss = []
    model.train()
    for i, (x_train, y_train) in enumerate(train_loader):
        xf, yf, mask = PatchNonzeroFilter()(x_train, y_train)
        xf, yf = xf.to(device), yf.to(device)
        optimizer.zero_grad()
        yhat = model(xf)
        loss = criterion(yhat, yf)
        loss.backward()
        optimizer.step()
        epoch_train_loss.append(loss.item())
    train_loss.append(np.mean(epoch_train_loss))
    # validation
    epoch_valid_loss = []
    model.eval()
    with torch.no_grad():
        for j, (x_valid, y_valid) in enumerate(valid_loader):
            xvf, yvf, vmask = PatchNonzeroFilter()(x_valid, y_valid)
            xvf, yvf = xvf.to(device), yvf.to(device)
            yvhat = model(xvf)
            vloss = criterion(yvhat, yvf)
            epoch_valid_loss.append(vloss.item())
    valid_loss.append(np.mean(epoch_valid_loss))
    # monitor
    if epoch % monitor == 0:
        print('Epoch: {} | Loss: {:.4f} | Valid Loss: {:.4f}'.format(epoch, train_loss[-1], valid_loss[-1]))

losses = pd.DataFrame({'train': train_loss, 'valid': valid_loss})
losses.to_csv('losses.csv')
torch.save(model.state_dict(), 'model.pth')

# parameters: 11,859
Epoch: 0 | Loss: 0.4326 | Valid Loss: 0.5911
Epoch: 10 | Loss: 0.4896 | Valid Loss: 0.6300
Epoch: 20 | Loss: 0.4572 | Valid Loss: 0.6030
Epoch: 30 | Loss: 0.4691 | Valid Loss: 0.5880
Epoch: 40 | Loss: 0.4622 | Valid Loss: 0.6563
Epoch: 50 | Loss: 0.5019 | Valid Loss: 0.6161
Epoch: 60 | Loss: 0.4941 | Valid Loss: 0.5862


KeyboardInterrupt: 

In [20]:
k = 0
for i, (x,y) in tqdm(enumerate(all_loader)):
    print(x.shape, y.shape)
    xf, yf, mask = PatchNonzeroFilter()(x,y)
    #xf, yf, mask = xf.to(device), yf.to(device), mask.to(device)
    yp = model(xf.to(device)).detach().cpu()
    yu = PatchReconstruct()(yp, mask)
    for j in range(yu.shape[0]):
        ypi = yu[j].squeeze().cpu().numpy()
        np.save('data/y_predictions/pimg_{}.npy'.format(k), ypi)
        k += 1
    break

0it [00:00, ?it/s]

torch.Size([8, 48, 1, 504, 504]) torch.Size([8, 48, 1, 504, 504])


0it [00:28, ?it/s]


KeyboardInterrupt: 

In [None]:
# Inference Loop (and save)
k = 0
for i, (x,y) in tqdm(enumerate(all_loader)):
    xfilt, yfilt, mask = PatchNonzeroFilter()(x,y)
    ypred = model(xfilt)
    yunfilt = PatchReconstruct()(ypred, mask)
    for j in range(yunfilt.shape[0]):
        ypred_img = yunfilt[j].squeeze().cpu().numpy()
        np.save('data/y_predictions/pimg_{}.npy'.format(k), ypred_img)
        k += 1

***
# END