In [6]:
import torch
from torchvision.datasets import ImageFolder
from torchvision import transforms
from sentence_transformers import SentenceTransformer, util
import random

# Load the trained model
model_path = 'C:/Users/jiyoo/workspace/MakeAIWork3/project3/src/apple_resnet_classifier.pt'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.load(model_path, map_location=device)
model.to(device)
model.eval()

# Define the image transformations
transform_img_normal = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Prompt the user for the folder location
folder_url = input('Enter folder location: ')

# Load the dataset
dataset = ImageFolder(folder_url, transform=transform_img_normal)
dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=50, shuffle=False)

# Function to predict the class labels
def predict(model, data):
    class_labels = ['Bad Apple', 'Normal Apple', 'Rot Apple', 'Scab Apple']
    class_counts = [0, 0, 0, 0]

    with torch.no_grad():
        for inputs, labels in data:
            inputs = inputs.to(device)
            out = model(inputs)
            _, predicted = torch.max(out.data, 1)

            for p in predicted:
                class_counts[p.item()] += 1

    label_counts_dict = {label: count for label, count in zip(class_labels, class_counts)}
    return label_counts_dict

# Perform prediction on the dataset
apple_counts = predict(model, dataset_loader)

# Get the apple counts from the prediction result
apple_counts = list(apple_counts.values())

# Specify the desired sample size
sample_size = int(input('Enter the sample size: '))

# Randomly select test apples from available_apples
random_test_apples = random.choices(apple_counts, k=sample_size)

print('Randomly selected apples:', random_test_apples)

# Include the AQL calculation based on sample size
aql_calculation = f"The AQL of this batch can be calculated from {sample_size} apples."
print(aql_calculation)

# Create a list to store the passages
mylist = [
    aql_calculation,
    f"n is sample size and AQL is calculated based on the number of 'Normal Apple's from all {sample_size} number of sample_size apples.",
    f"The class of apples is based on {random_test_apples} in every time",
    "AQL(Acceptable Quality Limit) of class 1 is 0.4 or less than 0.4 and it can be accepted as class 1 if n(batch size)=32 and x(the number of not Normal Apple) < 1.",
    "AQL of class 2 is more than 0.4 and less than 6.5 or equal to 6.5 and it can be accepted as class 2 if n=20 and x<8.",
    "AQL of class 3 is more than 6.5 and less than 15 or equal to 15 and it can be accepted as class 3 if n=20 and x<15.",
    "AQL of class 4 is more than 15 and it can be class 4 if n=20 and x >= 15.",
    "The quality of the batch is very good if it is class 1.",
    "The quality of the batch is bad if it is class 4.",
    "Only 'Normal Apple' can be accepted in AQL"
]

# Load the SentenceTransformer model
modelchat = SentenceTransformer('all-MiniLM-L12-v2')

# Encode the passages
passage_embeddings = modelchat.encode(mylist)

# Prompt the user for a query
query = input('What is your question? ')

# Encode the query
query_embedding = modelchat.encode([query])

# Calculate the similarity scores
similarity_scores = util.cos_sim(query_embedding, passage_embeddings)

# Find the index of the most similar passage
most_similar_index = similarity_scores.argmax()

# Retrieve the most similar passage and its similarity score
most_similar_passage = mylist[most_similar_index]
similarity_score = similarity_scores[0][most_similar_index]

# Print the most similar passage and its similarity score
print('Most similar passage:', most_similar_passage)
print('Similarity score:', similarity_score)

# Check if the most similar passage contains AQL information
if 'AQL' in most_similar_passage:
    # Extract AQL from the most similar passage
    aql_start = most_similar_passage.find('AQL') + 4
    aql_end = most_similar_passage.find('is') - 1
    aql = most_similar_passage[aql_start:aql_end]

    # Print the AQL
    print('AQL:', aql)
else:
    print('AQL information not found in the most similar passage.')


Randomly selected apples: [141, 141, 108, 120, 152, 108, 141, 120, 152, 152, 141, 141, 152, 152, 152, 152, 152, 141, 108, 141, 108, 108, 108, 152, 152, 141, 141, 120, 120, 152, 120, 152]
The AQL of this batch can be calculated from 32 apples.
Most similar passage: The quality of the batch is very good if it is class 1.
Similarity score: tensor(0.6592)
AQL information not found in the most similar passage.
