In [4]:
%pip install torch pandas torchvision scikit-learn tqdm kaggle torchmetrics huggingface_hub -q

Note: you may need to restart the kernel to use updated packages.


In [6]:
from huggingface_hub import hf_hub_download
import os
import zipfile

# Files to fetch from the Hugging Face Hub repo. Despite the variable name,
# this is the zipped Mendeley dataset, not model weights.
model_files = [
    "mendeley.zip"
]

# Target local directory for the downloaded file(s)
local_dir = "."
os.makedirs(local_dir, exist_ok=True)

# Archives are unpacked here; the next cell reads from `data3/mendeley`.
extract_dir = "data3"

# Download each file, then unpack it if it is a zip archive.
for filename in model_files:
    # hf_hub_download returns the local path of the downloaded file; use it
    # instead of assuming where the file landed.
    archive_path = hf_hub_download(
        repo_id="omkar334/agri",
        filename=filename,
        local_dir=local_dir,
    )
    if not filename.endswith(".zip"):
        continue
    # Pure-Python extraction replaces `!unzip -q ... -d data3`. It needs no
    # external `unzip` binary and overwrites existing files on re-run instead
    # of stopping at unzip's interactive "replace?" prompt.
    with zipfile.ZipFile(archive_path) as archive:
        archive.extractall(extract_dir)

In [11]:
from dataset import get_dataloaders

# Root of the extracted Mendeley data (created by the download cell).
DATA_ROOT = 'data3/mendeley'
# Second positional argument to get_dataloaders; presumably the batch size —
# confirm against dataset.get_dataloaders.
BATCH_SIZE = 16

(
    train_loader_labeled,
    val_loader_labeled,
    train_loader_unlabeled,
) = get_dataloaders(DATA_ROOT, BATCH_SIZE)

In [12]:
from convnext import Autoencoder
import torch
import torch.nn as nn
import torch.optim as optim

# Argument to Autoencoder; the same value is passed to validate_model later,
# so it is presumably the number of target classes — confirm against
# convnext.Autoencoder and keep both call sites in sync.
NUM_CLASSES = 11
LEARNING_RATE = 1e-4
SEED = 42

# Seed torch's global RNG (CPU and CUDA) so weight initialisation, and any
# torch-driven shuffling during training, is reproducible on Restart & Run All.
torch.manual_seed(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Autoencoder(NUM_CLASSES).to(DEVICE)
reconstruction_criterion = nn.MSELoss()  # Unsupervised Loss
classification_criterion = nn.CrossEntropyLoss()  # Supervised Loss
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
from train import train_model, validate_model

# Number of passes over the training loaders.
NUM_EPOCHS = 35

# Train on both labeled and unlabeled loaders with the two criteria defined
# above; the per-epoch log reports reconstruction and classification losses.
# NOTE(review): the recorded run below was interrupted during epoch 28/35, so
# `model` was never reassigned from the return value and holds mid-epoch weights.
model = train_model(
    model,
    train_loader_labeled,
    train_loader_unlabeled,
    reconstruction_criterion,
    classification_criterion,
    optimizer,
    num_epochs=NUM_EPOCHS,
)

                                                             

Epoch [1/35], Total Loss: 714.1368, Reconstruction Loss: 26.1035, Classification Loss: 688.0333, Accuracy: 0.2678


                                                             

Epoch [2/35], Total Loss: 663.0012, Reconstruction Loss: 25.3293, Classification Loss: 637.6719, Accuracy: 0.3377


                                                             

Epoch [3/35], Total Loss: 602.5347, Reconstruction Loss: 25.0346, Classification Loss: 577.5000, Accuracy: 0.4041


                                                             

Epoch [4/35], Total Loss: 525.2872, Reconstruction Loss: 24.8172, Classification Loss: 500.4700, Accuracy: 0.4809


                                                             

Epoch [5/35], Total Loss: 394.1712, Reconstruction Loss: 24.6765, Classification Loss: 369.4947, Accuracy: 0.6209


                                                             

Epoch [6/35], Total Loss: 243.8015, Reconstruction Loss: 24.5533, Classification Loss: 219.2482, Accuracy: 0.7789


                                                             

Epoch [7/35], Total Loss: 141.3469, Reconstruction Loss: 24.4626, Classification Loss: 116.8843, Accuracy: 0.8813


                                                             

Epoch [8/35], Total Loss: 90.6617, Reconstruction Loss: 24.4028, Classification Loss: 66.2588, Accuracy: 0.9377


                                                             

Epoch [9/35], Total Loss: 68.7579, Reconstruction Loss: 24.3541, Classification Loss: 44.4038, Accuracy: 0.9562


                                                              

Epoch [10/35], Total Loss: 55.8705, Reconstruction Loss: 24.3054, Classification Loss: 31.5651, Accuracy: 0.9742


                                                              

Epoch [11/35], Total Loss: 46.0644, Reconstruction Loss: 24.2831, Classification Loss: 21.7812, Accuracy: 0.9815


                                                              

Epoch [12/35], Total Loss: 50.7490, Reconstruction Loss: 24.2617, Classification Loss: 26.4873, Accuracy: 0.9770


                                                              

Epoch [13/35], Total Loss: 52.6959, Reconstruction Loss: 24.2629, Classification Loss: 28.4330, Accuracy: 0.9714


                                                              

Epoch [14/35], Total Loss: 41.6377, Reconstruction Loss: 24.2348, Classification Loss: 17.4029, Accuracy: 0.9842


                                                              

Epoch [15/35], Total Loss: 42.6851, Reconstruction Loss: 24.2273, Classification Loss: 18.4578, Accuracy: 0.9824


                                                              

Epoch [16/35], Total Loss: 37.7725, Reconstruction Loss: 24.2023, Classification Loss: 13.5702, Accuracy: 0.9853


                                                              

Epoch [17/35], Total Loss: 43.1390, Reconstruction Loss: 24.1973, Classification Loss: 18.9417, Accuracy: 0.9809


                                                              

Epoch [18/35], Total Loss: 34.3393, Reconstruction Loss: 24.1363, Classification Loss: 10.2030, Accuracy: 0.9891


                                                              

Epoch [19/35], Total Loss: 32.8811, Reconstruction Loss: 24.1178, Classification Loss: 8.7634, Accuracy: 0.9900


                                                              

Epoch [20/35], Total Loss: 32.6262, Reconstruction Loss: 24.0967, Classification Loss: 8.5294, Accuracy: 0.9894


                                                              

Epoch [21/35], Total Loss: 54.4405, Reconstruction Loss: 24.1539, Classification Loss: 30.2866, Accuracy: 0.9705


                                                              

Epoch [22/35], Total Loss: 38.8598, Reconstruction Loss: 24.1209, Classification Loss: 14.7389, Accuracy: 0.9850


                                                              

Epoch [23/35], Total Loss: 35.7974, Reconstruction Loss: 24.0898, Classification Loss: 11.7076, Accuracy: 0.9868


                                                              

Epoch [24/35], Total Loss: 31.6464, Reconstruction Loss: 24.0639, Classification Loss: 7.5826, Accuracy: 0.9911


                                                              

Epoch [25/35], Total Loss: 31.1832, Reconstruction Loss: 24.0458, Classification Loss: 7.1374, Accuracy: 0.9902


                                                              

Epoch [26/35], Total Loss: 32.2272, Reconstruction Loss: 24.0435, Classification Loss: 8.1837, Accuracy: 0.9896


                                                              

Epoch [27/35], Total Loss: 31.3348, Reconstruction Loss: 24.0370, Classification Loss: 7.2978, Accuracy: 0.9894


Epoch 28/35:  73%|███████▎  | 245/337 [01:23<00:30,  2.98it/s]

In [16]:
from pathlib import Path

# Where the trained weights are written. Nothing earlier in the notebook creates
# the `models/` folder, so make it here; otherwise torch.save fails on a fresh
# checkout.
# NOTE(review): in the recorded run, training was interrupted in epoch 28/35, so
# this checkpoint holds partially trained weights.
checkpoint_path = Path("models") / "convnext_mendeley.pth"
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)

In [18]:
# Evaluate on the held-out labeled split. This must match the value the model
# was constructed with (Autoencoder(11)).
# NOTE(review): the recorded output shows ~0.69 validation accuracy vs ~0.99
# training accuracy, evaluated on only 256 samples — check for overfitting and
# whether validate_model covers the whole validation split.
num_classes = 11
validate_model(model, val_loader_labeled, num_classes)

                                                           

Validation Accuracy: 0.6914
Validation Precision: 0.7020
Validation Recall: 0.6241
Validation F1 Score: 0.6460
Number of samples validated on: 256


In [15]:
import time