In [1]:
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import torch
import torch.nn as nn
import os
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import random

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

In [3]:
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.fc = nn.Linear(1000, 6)

    def forward(self, x):
        x = model(**x).logits[0]
        x = F.gelu(self.fc(x))
        return x

net = ConvNet()
#GPU = torch.device("cuda")
#net.to(GPU)

for name, param in model.named_parameters():
    if not ('classifier' in name or 'fc' in name):
        param.requires_grad = False

for param in net.fc.parameters():
    param.requires_grad = True

In [18]:
loss = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=0.00001)

In [5]:
def train_nn(question_tensor, answer_tensor, model, loss=loss, optimizer=optimizer):
    optimizer.zero_grad()
    outputs = model(question_tensor)
    loss = loss(outputs, answer_tensor)
    loss.backward()
    optimizer.step()

In [6]:
def preprocess_image(image_path):
    image = Image.open(image_path)
    inputs = processor(image, return_tensors="pt")
    return inputs

In [7]:
data = [[os.listdir(r"C:/Users/user/Desktop/text nn/injury dataset/abrasion"), torch.tensor([1, 0, 0, 0, 0, 0], dtype=torch.float32)],
        [os.listdir(r"C:/Users/user/Desktop/text nn/injury dataset/allergic"), torch.tensor([0, 1, 0, 0, 0, 0], dtype=torch.float32)],
        [os.listdir(r"C:/Users/user/Desktop/text nn/injury dataset/blisters"), torch.tensor([0, 0, 1, 0, 0, 0], dtype=torch.float32)],
        [os.listdir(r"C:/Users/user/Desktop/text nn/injury dataset/bruises"), torch.tensor([0, 0, 0, 1, 0, 0], dtype=torch.float32)],
        [os.listdir(r"C:/Users/user/Desktop/text nn/injury dataset/burn"), torch.tensor([0, 0, 0, 0, 1, 0], dtype=torch.float32)],
        [os.listdir(r"C:/Users/user/Desktop/text nn/injury dataset/laceration"), torch.tensor([0, 0, 0, 0, 0, 1], dtype=torch.float32)]]

dirs = [r"C:/Users/user/Desktop/text nn/injury dataset/abrasion",
        r"C:/Users/user/Desktop/text nn/injury dataset/allergic",
        r"C:/Users/user/Desktop/text nn/injury dataset/blisters",
        r"C:/Users/user/Desktop/text nn/injury dataset/bruises",
        r"C:/Users/user/Desktop/text nn/injury dataset/burn",
        r"C:/Users/user/Desktop/text nn/injury dataset/laceration"]

questions = []
answers = []

for x in tqdm(range(len(data))):
        for y in data[x][0]:
                questions.append(preprocess_image(f"{dirs[x]}/{y}"))
                answers.append(data[x][1])

val_num = 25
val_questions = []
val_answers = []

for x in tqdm(range(val_num)):
        chosen = random.randint(0, len(questions) - 1)
        
        val_questions.append(questions[chosen])
        val_answers.append(answers[chosen])
        
        questions.pop(chosen)
        answers.pop(chosen)

100%|██████████| 6/6 [00:10<00:00,  1.83s/it]
100%|██████████| 25/25 [00:00<?, ?it/s]


In [8]:
combined = list(zip(questions, answers))
random.shuffle(combined)
questions, answers = zip(*combined)

questions = list(questions)
answers = list(answers)

In [9]:
for x in tqdm(range(len(questions))):
    train_nn(questions[x], answers[x], net)

100%|██████████| 2009/2009 [13:38<00:00,  2.45it/s]


In [15]:
errors = 0

for x in tqdm(range(len(val_questions))):
    if torch.argmax(net(val_questions[x])) != torch.argmax(val_answers[x]):
        errors += 1
        print(x)

  4%|▍         | 1/25 [00:00<00:08,  2.86it/s]

0


  8%|▊         | 2/25 [00:00<00:08,  2.86it/s]

1


 12%|█▏        | 3/25 [00:01<00:07,  2.84it/s]

2


 20%|██        | 5/25 [00:01<00:07,  2.79it/s]

4


 28%|██▊       | 7/25 [00:02<00:06,  2.79it/s]

6


 32%|███▏      | 8/25 [00:02<00:06,  2.72it/s]

7


 40%|████      | 10/25 [00:03<00:05,  2.74it/s]

9


 44%|████▍     | 11/25 [00:03<00:05,  2.74it/s]

10


 48%|████▊     | 12/25 [00:04<00:04,  2.75it/s]

11


 56%|█████▌    | 14/25 [00:05<00:03,  2.76it/s]

13


 60%|██████    | 15/25 [00:05<00:03,  2.76it/s]

14


 64%|██████▍   | 16/25 [00:05<00:03,  2.74it/s]

15


 68%|██████▊   | 17/25 [00:06<00:02,  2.71it/s]

16


 72%|███████▏  | 18/25 [00:06<00:02,  2.72it/s]

17


 76%|███████▌  | 19/25 [00:06<00:02,  2.74it/s]

18


 84%|████████▍ | 21/25 [00:07<00:01,  2.75it/s]

20


 92%|█████████▏| 23/25 [00:08<00:00,  2.76it/s]

22


 96%|█████████▌| 24/25 [00:08<00:00,  2.75it/s]

23


100%|██████████| 25/25 [00:09<00:00,  2.76it/s]

24





In [16]:
net(val_questions[0])

tensor([ 0.0735,  0.3338,  0.4180,  0.0399,  0.1312, -0.1150],
       grad_fn=<GeluBackward0>)

In [17]:
val_answers[0]

tensor([0., 1., 0., 0., 0., 0.])

In [14]:
errors

19