In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import torch
import torch.optim as optim
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, cohen_kappa_score
from PIL import Image
from tqdm import tqdm
from collections import OrderedDict
from torch.optim.lr_scheduler import ReduceLROnPlateau
from transformers import ViTForImageClassification, ViTFeatureExtractor


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [2]:
print(os.listdir('/kaggle/input/nsl-kdd-dataset/'))

['KDDTest+.arff', 'KDDTest-21.arff', 'KDDTest1.jpg', 'KDDTrain+.txt', 'KDDTrain+_20Percent.txt', 'KDDTest-21.txt', 'KDDTest+.txt', 'KDDTrain+.arff', 'index.html', 'nsl-kdd', 'KDDTrain+_20Percent.arff', 'KDDTrain1.jpg']


In [3]:
zip_path = '/kaggle/input/nsl-kdd-dataset/nsl-kdd/KDDTrain+.txt'  # Adjust based on your dataset name
print(zip_path)

/kaggle/input/nsl-kdd-dataset/nsl-kdd/KDDTrain+.txt


In [4]:
!pip install torchinfo
from torchinfo import summary

  pid, fd = os.forkpty()




In [5]:
train_file_path = '/kaggle/input/nsl-kdd-dataset/nsl-kdd/KDDTrain+.txt'
# Load the dataset into a pandas DataFrame
dfTrain = pd.read_csv(train_file_path, header=None, sep=',')

dfTrain.head()

test_file_path = '/kaggle/input/nsl-kdd-dataset/nsl-kdd/KDDTest+.txt'
# Load the dataset into a pandas DataFrame
dfTest = pd.read_csv(test_file_path, header=None, sep=',')

dfTest.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,33,34,35,36,37,38,39,40,41,42
0,0,tcp,private,REJ,0,0,0,0,0,0,...,0.04,0.06,0.0,0.0,0.0,0.0,1.0,1.0,neptune,21
1,0,tcp,private,REJ,0,0,0,0,0,0,...,0.0,0.06,0.0,0.0,0.0,0.0,1.0,1.0,neptune,21
2,2,tcp,ftp_data,SF,12983,0,0,0,0,0,...,0.61,0.04,0.61,0.02,0.0,0.0,0.0,0.0,normal,21
3,0,icmp,eco_i,SF,20,0,0,0,0,0,...,1.0,0.0,1.0,0.28,0.0,0.0,0.0,0.0,saint,15
4,1,tcp,telnet,RSTO,0,15,0,0,0,0,...,0.31,0.17,0.03,0.02,0.0,0.0,0.83,0.71,mscan,11


In [6]:
# List of attribute names (as per the provided ARFF header)
columns = [
    'duration', 'protocol_type', 'service', 'flag', 'src_bytes', 'dst_bytes', 'land', 
    'wrong_fragment', 'urgent', 'hot', 'num_failed_logins', 'logged_in', 'num_compromised', 
    'root_shell', 'su_attempted', 'num_root', 'num_file_creations', 'num_shells', 'num_access_files', 
    'num_outbound_cmds', 'is_host_login', 'is_guest_login', 'count', 'srv_count', 'serror_rate', 
    'srv_serror_rate', 'rerror_rate', 'srv_rerror_rate', 'same_srv_rate', 'diff_srv_rate', 
    'srv_diff_host_rate', 'dst_host_count', 'dst_host_srv_count', 'dst_host_same_srv_rate', 
    'dst_host_diff_srv_rate', 'dst_host_same_src_port_rate', 'dst_host_srv_diff_host_rate', 
    'dst_host_serror_rate', 'dst_host_srv_serror_rate', 'dst_host_rerror_rate', 'dst_host_srv_rerror_rate', 
    'class', 'difficulty level'
]

# Assign the column names to the DataFrame
dfTrain.columns = columns

# Assign the column names to the DataFrame
dfTest.columns = columns

# # Display the first few rows to verify
# print(dfTrain.head())
print("number of columns", len(dfTrain.columns))
print("number of columns", len(dfTest.columns))

number of columns 43
number of columns 43


In [7]:
dfTrain.info()
dfTest.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 125973 entries, 0 to 125972
Data columns (total 43 columns):
 #   Column                       Non-Null Count   Dtype  
---  ------                       --------------   -----  
 0   duration                     125973 non-null  int64  
 1   protocol_type                125973 non-null  object 
 2   service                      125973 non-null  object 
 3   flag                         125973 non-null  object 
 4   src_bytes                    125973 non-null  int64  
 5   dst_bytes                    125973 non-null  int64  
 6   land                         125973 non-null  int64  
 7   wrong_fragment               125973 non-null  int64  
 8   urgent                       125973 non-null  int64  
 9   hot                          125973 non-null  int64  
 10  num_failed_logins            125973 non-null  int64  
 11  logged_in                    125973 non-null  int64  
 12  num_compromised              125973 non-null  int64  
 13 

In [8]:
# remove NULL values 

for i in dfTrain.columns:
    # Check Any NA Values
    if dfTrain[i].isnull().sum() > 0:
        print(i, dfTrain[i].isnull().sum())



# Check for duplicate rows
duplicates = dfTrain.duplicated().sum()

# Output the number of duplicate rows
print(f"Number of duplicate rows: {duplicates}")
print("number of columns", len(dfTrain.columns))

Number of duplicate rows: 0
number of columns 43


In [9]:
# remove NULL values 

for i in dfTest.columns:
    # Check Any NA Values
    if dfTest[i].isnull().sum() > 0:
        print(i, dfTest[i].isnull().sum())



# Check for duplicate rows
duplicates = dfTest.duplicated().sum()

# Output the number of duplicate rows
print(f"Number of duplicate rows: {duplicates}")
print("number of columns", len(dfTest.columns))

Number of duplicate rows: 0
number of columns 43


In [10]:
# Get Numeric Columns
numeric_columns = dfTrain.select_dtypes(include=[np.number]).columns
other_columns = dfTrain.columns.difference(numeric_columns)

len(numeric_columns), len(other_columns)

print("NUMERIC COLUMNS:", numeric_columns)
print("OTHER COLUMNS:", other_columns)


output_column = 'class'
other_columns = other_columns.difference([output_column])

numeric_columns, other_columns, output_column
print("number of columns", len(dfTrain.columns))

NUMERIC COLUMNS: Index(['duration', 'src_bytes', 'dst_bytes', 'land', 'wrong_fragment',
       'urgent', 'hot', 'num_failed_logins', 'logged_in', 'num_compromised',
       'root_shell', 'su_attempted', 'num_root', 'num_file_creations',
       'num_shells', 'num_access_files', 'num_outbound_cmds', 'is_host_login',
       'is_guest_login', 'count', 'srv_count', 'serror_rate',
       'srv_serror_rate', 'rerror_rate', 'srv_rerror_rate', 'same_srv_rate',
       'diff_srv_rate', 'srv_diff_host_rate', 'dst_host_count',
       'dst_host_srv_count', 'dst_host_same_srv_rate',
       'dst_host_diff_srv_rate', 'dst_host_same_src_port_rate',
       'dst_host_srv_diff_host_rate', 'dst_host_serror_rate',
       'dst_host_srv_serror_rate', 'dst_host_rerror_rate',
       'dst_host_srv_rerror_rate', 'difficulty level'],
      dtype='object')
OTHER COLUMNS: Index(['class', 'flag', 'protocol_type', 'service'], dtype='object')
number of columns 43


In [11]:
# Get Numeric Columns
numeric_columns2 = dfTest.select_dtypes(include=[np.number]).columns
other_columns2 = dfTest.columns.difference(numeric_columns2)

len(numeric_columns2), len(other_columns2)

print("NUMERIC COLUMNS:", numeric_columns2)
print("OTHER COLUMNS:", other_columns2)


output_column2 = 'class'
other_columns2 = other_columns2.difference([output_column2])

numeric_columns2, other_columns2, output_column2
print("number of columns", len(dfTest.columns))

NUMERIC COLUMNS: Index(['duration', 'src_bytes', 'dst_bytes', 'land', 'wrong_fragment',
       'urgent', 'hot', 'num_failed_logins', 'logged_in', 'num_compromised',
       'root_shell', 'su_attempted', 'num_root', 'num_file_creations',
       'num_shells', 'num_access_files', 'num_outbound_cmds', 'is_host_login',
       'is_guest_login', 'count', 'srv_count', 'serror_rate',
       'srv_serror_rate', 'rerror_rate', 'srv_rerror_rate', 'same_srv_rate',
       'diff_srv_rate', 'srv_diff_host_rate', 'dst_host_count',
       'dst_host_srv_count', 'dst_host_same_srv_rate',
       'dst_host_diff_srv_rate', 'dst_host_same_src_port_rate',
       'dst_host_srv_diff_host_rate', 'dst_host_serror_rate',
       'dst_host_srv_serror_rate', 'dst_host_rerror_rate',
       'dst_host_srv_rerror_rate', 'difficulty level'],
      dtype='object')
OTHER COLUMNS: Index(['class', 'flag', 'protocol_type', 'service'], dtype='object')
number of columns 43


In [12]:
# Number of Distinct Values in Each Column
for i in other_columns:
    print(i, " -> Number of Distinct Values:", dfTrain[i].nunique())

flag  -> Number of Distinct Values: 11
protocol_type  -> Number of Distinct Values: 3
service  -> Number of Distinct Values: 70


In [13]:
columnNames = []
for col in dfTrain.columns:
    columnNames.append(col)

print("columnNames = ",columnNames)

columnNames =  ['duration', 'protocol_type', 'service', 'flag', 'src_bytes', 'dst_bytes', 'land', 'wrong_fragment', 'urgent', 'hot', 'num_failed_logins', 'logged_in', 'num_compromised', 'root_shell', 'su_attempted', 'num_root', 'num_file_creations', 'num_shells', 'num_access_files', 'num_outbound_cmds', 'is_host_login', 'is_guest_login', 'count', 'srv_count', 'serror_rate', 'srv_serror_rate', 'rerror_rate', 'srv_rerror_rate', 'same_srv_rate', 'diff_srv_rate', 'srv_diff_host_rate', 'dst_host_count', 'dst_host_srv_count', 'dst_host_same_srv_rate', 'dst_host_diff_srv_rate', 'dst_host_same_src_port_rate', 'dst_host_srv_diff_host_rate', 'dst_host_serror_rate', 'dst_host_srv_serror_rate', 'dst_host_rerror_rate', 'dst_host_srv_rerror_rate', 'class', 'difficulty level']


In [14]:
dfTrain.head()
dfTest.head()

Unnamed: 0,duration,protocol_type,service,flag,src_bytes,dst_bytes,land,wrong_fragment,urgent,hot,...,dst_host_same_srv_rate,dst_host_diff_srv_rate,dst_host_same_src_port_rate,dst_host_srv_diff_host_rate,dst_host_serror_rate,dst_host_srv_serror_rate,dst_host_rerror_rate,dst_host_srv_rerror_rate,class,difficulty level
0,0,tcp,private,REJ,0,0,0,0,0,0,...,0.04,0.06,0.0,0.0,0.0,0.0,1.0,1.0,neptune,21
1,0,tcp,private,REJ,0,0,0,0,0,0,...,0.0,0.06,0.0,0.0,0.0,0.0,1.0,1.0,neptune,21
2,2,tcp,ftp_data,SF,12983,0,0,0,0,0,...,0.61,0.04,0.61,0.02,0.0,0.0,0.0,0.0,normal,21
3,0,icmp,eco_i,SF,20,0,0,0,0,0,...,1.0,0.0,1.0,0.28,0.0,0.0,0.0,0.0,saint,15
4,1,tcp,telnet,RSTO,0,15,0,0,0,0,...,0.31,0.17,0.03,0.02,0.0,0.0,0.83,0.71,mscan,11


In [15]:
from sklearn.preprocessing import MinMaxScaler

# Normalize numerical features
# numerical_columns = dfTrain.columns[:-1]  # All columns except 'class'
scaler = MinMaxScaler()
dfTrain[numeric_columns] = scaler.fit_transform(dfTrain[numeric_columns])
dfTest[numeric_columns] = scaler.fit_transform(dfTest[numeric_columns])

In [16]:
dfTrain.head()

Unnamed: 0,duration,protocol_type,service,flag,src_bytes,dst_bytes,land,wrong_fragment,urgent,hot,...,dst_host_same_srv_rate,dst_host_diff_srv_rate,dst_host_same_src_port_rate,dst_host_srv_diff_host_rate,dst_host_serror_rate,dst_host_srv_serror_rate,dst_host_rerror_rate,dst_host_srv_rerror_rate,class,difficulty level
0,0.0,tcp,ftp_data,SF,3.558064e-07,0.0,0.0,0.0,0.0,0.0,...,0.17,0.03,0.17,0.0,0.0,0.0,0.05,0.0,normal,0.952381
1,0.0,udp,other,SF,1.057999e-07,0.0,0.0,0.0,0.0,0.0,...,0.0,0.6,0.88,0.0,0.0,0.0,0.0,0.0,normal,0.714286
2,0.0,tcp,private,S0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.1,0.05,0.0,0.0,1.0,1.0,0.0,0.0,neptune,0.904762
3,0.0,tcp,http,SF,1.681203e-07,6.223962e-06,0.0,0.0,0.0,0.0,...,1.0,0.0,0.03,0.04,0.03,0.01,0.0,0.01,normal,1.0
4,0.0,tcp,http,SF,1.442067e-07,3.20626e-07,0.0,0.0,0.0,0.0,...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,normal,1.0


In [17]:
# label encode all columns 
# Import label encoder 

from sklearn import preprocessing 
  
# label_encoder object knows  
# how to understand word labels. 
label_encoder = preprocessing.LabelEncoder() 

# df['species'].unique() 

for col in other_columns:
    dfTrain[col]= label_encoder.fit_transform(dfTrain[col]) 

## label encoding for class column 
dfTrain['class'] = label_encoder.fit_transform(dfTrain['class'])


## for test dataset
for col in other_columns2:
    dfTest[col]= label_encoder.fit_transform(dfTest[col]) 

## label encoding for class column 
dfTest['class'] = label_encoder.fit_transform(dfTest['class'])

In [18]:
for col in dfTrain.columns:
    print("number of unique values in column ",col, " = ", dfTrain[col].nunique())

number of unique values in column  duration  =  2981
number of unique values in column  protocol_type  =  3
number of unique values in column  service  =  70
number of unique values in column  flag  =  11
number of unique values in column  src_bytes  =  3341
number of unique values in column  dst_bytes  =  9326
number of unique values in column  land  =  2
number of unique values in column  wrong_fragment  =  3
number of unique values in column  urgent  =  4
number of unique values in column  hot  =  28
number of unique values in column  num_failed_logins  =  6
number of unique values in column  logged_in  =  2
number of unique values in column  num_compromised  =  88
number of unique values in column  root_shell  =  2
number of unique values in column  su_attempted  =  3
number of unique values in column  num_root  =  82
number of unique values in column  num_file_creations  =  35
number of unique values in column  num_shells  =  3
number of unique values in column  num_access_files  

In [19]:
df_subset = dfTrain.head(10000)
df_subset.head()

df_test_subset = dfTest.head(1000)
df_test_subset.head()

Unnamed: 0,duration,protocol_type,service,flag,src_bytes,dst_bytes,land,wrong_fragment,urgent,hot,...,dst_host_same_srv_rate,dst_host_diff_srv_rate,dst_host_same_src_port_rate,dst_host_srv_diff_host_rate,dst_host_serror_rate,dst_host_srv_serror_rate,dst_host_rerror_rate,dst_host_srv_rerror_rate,class,difficulty level
0,0.0,1,45,1,0.0,0.0,0.0,0.0,0.0,0.0,...,0.04,0.06,0.0,0.0,0.0,0.0,1.0,1.0,14,1.0
1,0.0,1,45,1,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.06,0.0,0.0,0.0,0.0,1.0,1.0,14,1.0
2,3.5e-05,1,19,9,0.0002066513,0.0,0.0,0.0,0.0,0.0,...,0.61,0.04,0.61,0.02,0.0,0.0,0.0,0.0,16,1.0
3,0.0,0,13,9,3.183413e-07,0.0,0.0,0.0,0.0,0.0,...,1.0,0.0,1.0,0.28,0.0,0.0,0.0,0.0,24,0.714286
4,1.7e-05,1,55,2,0.0,1.1e-05,0.0,0.0,0.0,0.0,...,0.31,0.17,0.03,0.02,0.0,0.0,0.83,0.71,11,0.52381


In [20]:
for col in df_subset.columns:
    print("number of unique values in column ",col, " = ", df_subset[col].nunique())

number of unique values in column  duration  =  340
number of unique values in column  protocol_type  =  3
number of unique values in column  service  =  64
number of unique values in column  flag  =  11
number of unique values in column  src_bytes  =  992
number of unique values in column  dst_bytes  =  2140
number of unique values in column  land  =  1
number of unique values in column  wrong_fragment  =  3
number of unique values in column  urgent  =  2
number of unique values in column  hot  =  18
number of unique values in column  num_failed_logins  =  3
number of unique values in column  logged_in  =  2
number of unique values in column  num_compromised  =  17
number of unique values in column  root_shell  =  2
number of unique values in column  su_attempted  =  3
number of unique values in column  num_root  =  16
number of unique values in column  num_file_creations  =  10
number of unique values in column  num_shells  =  2
number of unique values in column  num_access_files  = 

In [21]:
# Check the number of samples for each output class in the 'class' column
class_counts = dfTrain['class'].value_counts()

# Display the result
print(class_counts)

class
11    67343
9     41214
17     3633
5      3599
15     2931
18     2646
10     1493
0       956
20      892
21      890
14      201
3        53
1        30
22       20
6        18
4        11
16       10
7         9
2         8
8         7
13        4
12        3
19        2
Name: count, dtype: int64


In [22]:
# Check the number of samples for each output class in the 'class' column
class_counts2 = df_subset['class'].value_counts()

# Display the result
print(class_counts2)

class
11    5292
9     3336
17     273
5      271
15     242
18     212
10     121
20      86
21      72
0       69
14      13
3        3
22       3
1        3
8        1
2        1
16       1
4        1
Name: count, dtype: int64


In [23]:
pip install imbalanced-learn

  pid, fd = os.forkpty()


Note: you may need to restart the kernel to use updated packages.


In [24]:
from imblearn.over_sampling import RandomOverSampler
import pandas as pd

# Assuming 'df' is your DataFrame, and it contains the feature columns and 'class' as the target column
X = df_subset.drop('class', axis=1)  # Feature columns
y = df_subset['class']  # Target column

# Define the oversampling strategy (for classes with less than 80 samples)
oversample_strategy = {label: 115 for label in y.value_counts()[y.value_counts() < 80].index}

# Apply random oversampling
ros = RandomOverSampler(sampling_strategy=oversample_strategy, random_state=42)
X_resampled, y_resampled = ros.fit_resample(X, y)

# Create a new DataFrame with resampled data
df_resampled = pd.DataFrame(X_resampled, columns=X.columns)
df_resampled['class'] = y_resampled

# Check the class distribution after oversampling
print(df_resampled['class'].value_counts())


class
11    5292
9     3336
17     273
5      271
15     242
18     212
10     121
21     115
14     115
0      115
16     115
3      115
2      115
8      115
4      115
1      115
22     115
20      86
Name: count, dtype: int64


In [25]:
x_train = df_subset.iloc[:, :-1]
y_train = df_subset.iloc[:, -1]

x_test = df_test_subset.iloc[:, :-1]
y_test = df_test_subset.iloc[:, -1]


In [26]:
print(type(x_train))
print(x_train.shape)


<class 'pandas.core.frame.DataFrame'>
(10000, 42)


In [27]:
## below are the most imp features to detect anamolies or attacks
crucial_columns = ['src_bytes', 'dst_bytes', 'count', 'srv_count', 'serror_rate', 'rerror_rate', 
                   'protocol_type', 'num_failed_logins', 'num_root', 'root_shell', 'num_compromised', 
                   'su_attempted','dst_host_count', 'dst_host_same_srv_rate', 'dst_host_serror_rate', 'hot' ]


In [28]:
df_subset.head()

Unnamed: 0,duration,protocol_type,service,flag,src_bytes,dst_bytes,land,wrong_fragment,urgent,hot,...,dst_host_same_srv_rate,dst_host_diff_srv_rate,dst_host_same_src_port_rate,dst_host_srv_diff_host_rate,dst_host_serror_rate,dst_host_srv_serror_rate,dst_host_rerror_rate,dst_host_srv_rerror_rate,class,difficulty level
0,0.0,1,20,9,3.558064e-07,0.0,0.0,0.0,0.0,0.0,...,0.17,0.03,0.17,0.0,0.0,0.0,0.05,0.0,11,0.952381
1,0.0,2,44,9,1.057999e-07,0.0,0.0,0.0,0.0,0.0,...,0.0,0.6,0.88,0.0,0.0,0.0,0.0,0.0,11,0.714286
2,0.0,1,49,5,0.0,0.0,0.0,0.0,0.0,0.0,...,0.1,0.05,0.0,0.0,1.0,1.0,0.0,0.0,9,0.904762
3,0.0,1,24,9,1.681203e-07,6.223962e-06,0.0,0.0,0.0,0.0,...,1.0,0.0,0.03,0.04,0.03,0.01,0.0,0.01,11,1.0
4,0.0,1,24,9,1.442067e-07,3.20626e-07,0.0,0.0,0.0,0.0,...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,11,1.0


In [29]:
# Select the crucial columns
total_data = x_train[crucial_columns].values

# Rescale the values to [0, 255] if they're in the range [0, 1]
total_data = (total_data * 255).astype(np.uint8)

train_data = total_data[:8000]
val_data = total_data[8000:]

labels = x_train.iloc[:, -1].values
train_labels = labels[:8000]
val_labels = labels[8000:]

###-------------------------------------------------test dataset-----------------------------
# Select the crucial columns
total_data2 = x_test[crucial_columns].values

# Rescale the values to [0, 255] if they're in the range [0, 1]
total_data2 = (total_data2 * 255).astype(np.uint8)

test_data = total_data2[:1000]

test_labels = x_test.iloc[:, -1].values
test_labels = test_labels[:1000]


In [30]:
# Reshape each row of 16 features into a 4x4 image
# Assuming the number of rows is 'num_samples'
train_images = train_data.reshape(-1, 4, 4)  # Reshape into 4x4 images for each sample
val_images = val_data.reshape(-1, 4, 4)  # Reshape into 4x4 images for each sample
test_images = test_data.reshape(-1, 4, 4) 

In [31]:
## convert all train_images into rgb 

train_images_rgb = []
for i, image in enumerate(train_images):
    if len(image.shape) == 2:
        
        # Convert grayscale to RGB by stacking it into 3 channels
        rgb_image = np.stack((image,) * 3, axis=-1)
        
        # Replace the grayscale image with the RGB image in the list
        train_images_rgb.append(rgb_image)

In [32]:
val_images_rgb = []
for i, image in enumerate(val_images):
    if len(image.shape) == 2:
        
        # Convert grayscale to RGB by stacking it into 3 channels
        rgb_image = np.stack((image,) * 3, axis=-1)
        
        # Replace the grayscale image with the RGB image in the list
        val_images_rgb.append(rgb_image)

In [33]:
test_images_rgb = []
for i, image in enumerate(test_images):
    if len(image.shape) == 2:
        
        # Convert grayscale to RGB by stacking it into 3 channels
        rgb_image = np.stack((image,) * 3, axis=-1)
        
        # Replace the grayscale image with the RGB image in the list
        test_images_rgb.append(rgb_image)

In [34]:
import numpy as np

def is_rgb(image):
    """
    Check if the image is RGB.
    """
    # Check the dimensions of the image
    if len(image.shape) == 3 and image.shape[-1] == 3:
        return True
    elif len(image.shape) == 2:  # Grayscale images have 2 dimensions
        return False
    else:
        return False  # Other formats (e.g., RGBA) need specific handling


In [35]:
list = []
list2 = []
list3 = []

for image in train_images_rgb:
    list.append(is_rgb(image))

print(len(list))

for image in val_images_rgb:
    list2.append(is_rgb(image))

print(len(list2))

for image in test_images_rgb:
    list3.append(is_rgb(image))

print(len(list3))

8000
2000
1000


In [36]:
# # Visualize a sample image from the training set
# plt.imshow(images[0], cmap='gray')
# plt.show()


In [37]:
# Define ViT input size (use a larger size than 4x4 for practical ViT use, e.g., 32x32)

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to a larger size for ViT
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),  # Convert to tensor
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to a larger size for ViT
    transforms.ToTensor(),  # Convert to tensor
])

test_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to a larger size for ViT
    transforms.ToTensor(),  # Convert to tensor
])


In [38]:
class CustomDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images  # The images (4x4 reshaped into images, can be resized later)
        self.labels = labels  # The corresponding labels
        self.transform = transform  # Transformation (resize, normalization, etc.)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        # Convert numpy array to PIL Image before applying transformations
        image = Image.fromarray(image)  # Ensure the image is in PIL format
        
        # Apply the transformations (resize, normalize, etc.)
        if self.transform:
            image = self.transform(image)

        # Convert the label to a tensor
        label = torch.tensor(label, dtype=torch.long)  # Ensure labels are tensors (long for classification)
        
        return image, label

In [39]:
# Evaluate the model
def validate_model():

    model.eval()
    correct_preds = 0
    total_preds = 0
    
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images.squeeze(1))
            _, preds = torch.max(outputs.logits, 1)
            correct_preds += (preds == labels).sum().item()
            total_preds += labels.size(0)
    
    test_acc = correct_preds / total_preds
    print(f"Test Accuracy: {test_acc:.4f}")



In [40]:
# Training loop

def train_model(num_epochs = 10):
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct_preds = 0
        total_preds = 0
        
        for images, labels in tqdm(train_loader, desc="Training", ):
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(images.squeeze(1))  # Remove the extra dimension for image (e.g., batch_size, 1, 7, 7)
            loss = criterion(outputs.logits, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            _, preds = torch.max(outputs.logits, 1)
            correct_preds += (preds == labels).sum().item()
            total_preds += labels.size(0)
        
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = correct_preds / total_preds
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")

        validate_model()



In [41]:
# Create Dataset objects
train_dataset = CustomDataset(train_images_rgb, train_labels, transform=train_transform)
val_dataset = CustomDataset(val_images_rgb, val_labels, transform=val_transform)
test_dataset = CustomDataset(test_images_rgb, test_labels, transform=test_transform)

# Create DataLoader objects
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4)


In [42]:
num_classes = 18 ## number of unique values in column class in train data including normal class

model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')


# Modify the input size (224x224) to 7x7 (because we are feeding in 7x7 grids)
model.config.image_size = 224
model.config.patch_size = 16  # Set patch size to 1 since we're feeding in small 7x7 patches
model.config.num_channels = 3  # Ensure the model expects RGB input (3 channels)


# Freeze the layers if needed
for param in model.parameters():
    param.requires_grad = False  # Optionally freeze model layers

# Unfreeze the classification head
for param in model.classifier.parameters():
    param.requires_grad = True  # Allow training of the classifier head

model = nn.DataParallel(model)
model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

In [None]:
train_model(50)

  with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):


Epoch [1/50], Loss: 0.4877, Accuracy: 0.8900
Test Accuracy: 0.9300
Epoch [2/50], Loss: 0.2558, Accuracy: 0.9346
Test Accuracy: 0.9360
Epoch [3/50], Loss: 0.2285, Accuracy: 0.9376
Test Accuracy: 0.9390
Epoch [4/50], Loss: 0.2138, Accuracy: 0.9407
Test Accuracy: 0.9320
Epoch [5/50], Loss: 0.2054, Accuracy: 0.9419
Test Accuracy: 0.9415
Epoch [6/50], Loss: 0.1921, Accuracy: 0.9446
Test Accuracy: 0.9300
Epoch [7/50], Loss: 0.1841, Accuracy: 0.9470
Test Accuracy: 0.9440
Epoch [8/50], Loss: 0.1882, Accuracy: 0.9447
Test Accuracy: 0.9495
Epoch [9/50], Loss: 0.1748, Accuracy: 0.9465
Test Accuracy: 0.9490


In [None]:
validate_model()