In [2]:
import os
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt

In [3]:
labelled_files_video = [f"labeled/{n}" for n in os.listdir("labeled") if ".hevc" in n ]
labelled_files_video

['labeled/3.hevc',
 'labeled/0.hevc',
 'labeled/4.hevc',
 'labeled/2.hevc',
 'labeled/1.hevc']

In [4]:
labelled_files_label = [n.replace('.hevc', '.txt') for n in labelled_files_video]
labelled_files_label

['labeled/3.txt',
 'labeled/0.txt',
 'labeled/4.txt',
 'labeled/2.txt',
 'labeled/1.txt']

In [5]:
unlabelled_files_video = [f"unlabeled/{n}" for n in os.listdir("unlabeled") if ".hevc" in n ]
unlabelled_files_video

['unlabeled/5.hevc',
 'unlabeled/6.hevc',
 'unlabeled/9.hevc',
 'unlabeled/8.hevc',
 'unlabeled/7.hevc']

In [6]:
IMAGE_SHAPE = (256, 256)

In [8]:
def data_loader(ds="train"):
    if ds == "train":
        for video_file, label_file in zip(labelled_files_video, labelled_files_label):
            vid = cv2.VideoCapture(video_file)
            lab = torch.from_numpy(np.nan_to_num(np.loadtxt(label_file))).unsqueeze(0).float()
            idx = 0
            while True:
                next_frame_found, frame = vid.read()
                if not next_frame_found:
                    break
                frame = cv2.resize(frame, IMAGE_SHAPE)
                image = torch.from_numpy(np.array(frame, np.float32) / 255.0).permute(2, 0, 1).unsqueeze(0)
                label = lab[:, idx, :]
                idx += 1
                yield image, label

    else:
        for video_file in unlabelled_files_video:
            vid = cv2.VideoCapture(video_file)
            while True:
                next_frame_found, frame = vid.read()
                if not next_frame_found:
                    break
                frame = cv2.resize(frame, IMAGE_SHAPE)
                image = torch.from_numpy(np.array(frame, np.float32) / 255.0).permute(2, 0, 1).unsqueeze(0)
                yield image, None

next(data_loader())[0].shape, next(data_loader())[1].shape

(torch.Size([1, 3, 256, 256]), torch.Size([1, 2]))

In [10]:
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Conv2d(3, 8, 2, 2)
        self.l2 = torch.nn.Conv2d(8, 16, 2, 2)
        self.l3 = torch.nn.ConvTranspose2d(16, 8, 2, 2)
        self.l4 = torch.nn.ConvTranspose2d(8, 1, 2, 2)
        self.l5 = torch.nn.Flatten()
        self.l6 = torch.nn.Linear(IMAGE_SHAPE[0] * IMAGE_SHAPE[1], 32)
        self.l7 = torch.nn.Linear(32, 2)

    
    def forward(self, x):
        x = self.l1(x)
        x = self.l2(x)
        x = self.l3(x)
        x = self.l4(x)
        x = self.l5(x)
        x = self.l6(x)
        x = self.l7(x)
        return x

model = Model()

In [11]:
print("Parameter count: ", sum((p.numel() for p in model.parameters() if p.requires_grad)) / 1e6 , " million")

Parameter count:  2.098435  million


# Train

In [12]:
EPOCHS = 1
optim = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.MSELoss()

for epoch in range(EPOCHS):
    train_data = data_loader("train")
    model.train()
    for image, label in train_data:
        optim.zero_grad()
        out = model(image)
        loss = loss_fn(out, label)
        loss.backward()
        optim.step()

    model.eval()
    train_data = data_loader("train")
    train_loss = 0.0
    count = 0
    with torch.no_grad():
        for image, label in train_data:
            out = model(image)
            loss = loss_fn(out, label)
            train_loss += loss
            count += 1

    print(f"loss epoch {epoch}  --  {train_loss / count}")


loss epoch 0  --  0.04395794868469238


In [31]:
torch.save(model.state_dict(), "best-model.pt")

# Get predictions on validation data

In [27]:
valid_data = data_loader(ds="valid")
valid_data_preds = []
file_idx = 0
for idx, (image, _) in enumerate(valid_data):
    valid_data_preds.append(model(image).detach().numpy())

    if len(valid_data_preds) == 1200:  # reset every 1200 predictions
        predictions = np.concatenate(valid_data_preds, 0)
        np.savetxt(unlabelled_files_video[file_idx].replace("hevc", "txt"), predictions)
        file_idx += 1
        valid_data_preds = []
