In [1]:
import torch
import torch.nn as nn
import SLDLoader.torch
import numpy as np
import random
import os

from tqdm import tqdm


class ModifiedLightweight3DCNN(nn.Module):
    def __init__(self):
        super(ModifiedLightweight3DCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=30, out_channels=8, kernel_size=(3, 3), stride=1, padding=1)
        self.res2 = self._make_layer(block_count=3, in_channels=8, out_channels=8)  # Modified in_channels
        self.res3 = self._make_layer(block_count=4, in_channels=8, out_channels=64)  # Modified in_channels
        self.res4 = self._make_layer(block_count=6, in_channels=64, out_channels=128)
        self.res5 = self._make_layer(block_count=3, in_channels=128, out_channels=256)
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, 1)

    def _make_layer(self, block_count, in_channels, out_channels):
        layers = []
        for _ in range(block_count):
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=1, padding=1))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU(inplace=True))
            in_channels = out_channels  # Update in_channels for the next layer
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.res2(x)
        x = self.res3(x)
        x = self.res4(x)
        x = self.res5(x)
        x = self.global_avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = torch.sigmoid(x)
        return x


def init_seed(seed):
    torch.cuda.manual_seed_all(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

In [8]:
data_folder = r'D:\2023-2024\Research\Align\SLD\SLD\output\skeletons'
sign_list = sorted(os.listdir(data_folder))
hightlight_sign = sign_list[0]
dataset = SLDLoader.torch.SLD(data_folder,30,
                                      32,42)
data_loader = torch.utils.data.DataLoader(
                dataset=dataset.get_generator(hightlight_sign,num_data=100),
                batch_size=32,
                num_workers=0,
                drop_last=True,pin_memory=True,
                worker_init_fn=init_seed)
model = ModifiedLightweight3DCNN()
criterion = nn.BCELoss()

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
model.train()
for epoch in range(100):
    for i, (data, label) in enumerate(data_loader):
        print(data.shape)
        data = torch.einsum('b t w c -> b t c w', data)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()
        print(f'Epoch: {epoch}, Iteration: {i}, Loss: {loss.item()}')
        

#test
model.eval()
#print accuracy, precision, recall, f1 score
correct = 0
total = 0
TP = 0
FP = 0
FN = 0
for i, (data, label) in enumerate(data_loader):
    data = torch.einsum('b t w c -> b t c w', data)
    output = model(data)
    output = output > 0.5
    print(torch.sum(data))
    total += label.size(0)
    correct += (output == label).sum().item()
    TP += ((output == label) & (output == 1)).sum().item()
    FP += ((output != label) & (output == 1)).sum().item()
    FN += ((output != label) & (output == 0)).sum().item()
accuracy = correct / total
precision = TP / ((TP + FP) if TP + FP != 0 else 1)

recall = TP / ((TP + FN) if TP + FN != 0 else 1)
f1 = 2 * precision * recall / (precision + recall)
print(f'Accuracy: {accuracy}, Precision: {precision}, Recall: {recall}, F1 Score: {f1}')

torch.Size([32, 30, 75, 2])
Epoch: 0, Iteration: 0, Loss: 0.6917152404785156
torch.Size([32, 30, 75, 2])
Epoch: 0, Iteration: 1, Loss: 0.6572794318199158
torch.Size([32, 30, 75, 2])
Epoch: 0, Iteration: 2, Loss: 0.7446579337120056
torch.Size([32, 30, 75, 2])
Epoch: 1, Iteration: 0, Loss: 1.0797319412231445
torch.Size([32, 30, 75, 2])
Epoch: 1, Iteration: 1, Loss: 0.7346938848495483
torch.Size([32, 30, 75, 2])
Epoch: 1, Iteration: 2, Loss: 0.6832076907157898
torch.Size([32, 30, 75, 2])
Epoch: 2, Iteration: 0, Loss: 0.7123180031776428
torch.Size([32, 30, 75, 2])
Epoch: 2, Iteration: 1, Loss: 0.6542520523071289
torch.Size([32, 30, 75, 2])
Epoch: 2, Iteration: 2, Loss: 0.6231831908226013
torch.Size([32, 30, 75, 2])
Epoch: 3, Iteration: 0, Loss: 0.652376115322113
torch.Size([32, 30, 75, 2])
Epoch: 3, Iteration: 1, Loss: 0.6213555335998535
torch.Size([32, 30, 75, 2])
Epoch: 3, Iteration: 2, Loss: 0.5539996027946472
torch.Size([32, 30, 75, 2])
Epoch: 4, Iteration: 0, Loss: 0.6905529499053955


In [9]:
for epoch in range(10):
    for i, (data, label) in enumerate(data_loader):
        data = torch.einsum('b t w c -> b t c w', data)
        output = model(data)
        output = output > 0.5
        total += label.size(0)
        correct += (output == label).sum().item()
        TP += ((output == label) & (output == 1)).sum().item()
        FP += ((output != label) & (output == 1)).sum().item()
        FN += ((output != label) & (output == 0)).sum().item()
accuracy = correct / total
precision = TP / ((TP + FP) if TP + FP != 0 else 1)

recall = TP / ((TP + FN) if TP + FN != 0 else 1)
f1 = 2 * precision * recall / (precision + recall)
print(f'Accuracy: {accuracy}, Precision: {precision}, Recall: {recall}, F1 Score: {f1}')

tensor(44834.2266)
tensor(45453.7773)
tensor(44064.6328)
tensor(46969.8828)
tensor(45332.3867)
tensor(45199.8672)
tensor(49869.3672)
tensor(47689.1016)
tensor(40137.6641)
tensor(46183.5352)
tensor(44338.8438)
tensor(45872.5586)
tensor(42721.7500)
tensor(42755.5547)
tensor(47319.9570)
tensor(54596.1719)
tensor(47254.9922)
tensor(38587.5742)
tensor(46108.6172)
tensor(41550.0352)
tensor(41973.7188)
tensor(41001.4180)
tensor(42210.7188)
tensor(47311.0664)
tensor(45119.7930)
tensor(46332.3984)
tensor(46962.8047)
tensor(40514.5859)
tensor(43685.4688)
tensor(43707.6016)
Accuracy: 0.8484848484848485, Precision: 0.8821052631578947, Recall: 0.8011472275334608, F1 Score: 0.8396793587174348
