In [1]:
from models import GenerationModel, EmbeddingsModel, MilvusClient
import os

# Initialize the models
embedding_model = EmbeddingsModel("sentence-transformers/bert-base-nli-mean-tokens")
embedding_model.load_model()
print('Embedding model loaded')

# Resolve the home directory with expanduser rather than os.getenv("HOME"):
# getenv returns None when HOME is unset, which would silently yield the
# bogus path "None/ext-gits/..." instead of failing or falling back.
home_dir = os.path.expanduser("~")
# The generation model is loaded from a local checkout under ~/ext-gits.
generation_model_name = os.path.join(home_dir, "ext-gits", "Mistral-7B-Instruct-v0.3")
generation_model = GenerationModel(generation_model_name)
generation_model.load_model()
print('Mistral loaded')

# Initialize the Milvus client and connect to the database
milvus_client = MilvusClient("wiki_movie_plots")
print('Milvus loaded')


Embedding model loaded


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Mistral loaded
Connected to Milvus at 127.0.0.1:19530 with alias 'default'
Collection 'wiki_movie_plots' loaded
Milvus loaded


In [2]:

# Load the movie-plot CSV into the Milvus collection, embedding the rows with
# `embedding_model` in batches of 128.
print('Ingesting data')
# Insert embeddings to Milvus if needed (takes time)
# NOTE(review): nothing in this cell checks whether the collection is already
# populated, so it runs the full ingestion every time it is executed. Whether
# a re-run duplicates records depends on ingest_data -- confirm before re-running.
csv_file_path = "data/wiki_movie_plots_deduped.csv"
# The return value is discarded; presumably this prepares/loads the collection
# on the client before ingestion -- TODO confirm against MilvusClient.
milvus_client.get_collection()
milvus_client.ingest_data(csv_file_path, embedding_model, batch_size=128)
print('Data ingested')

Ingesting data


Inserting Batches:   0%|          | 0/273 [00:00<?, ?batch/s]

Number of records in the collection: 34886
Data ingested


In [3]:
# Maintenance helpers, intentionally left disabled. Uncomment only when the
# collection must be rebuilt.
# NOTE(review): drop_collection presumably deletes all ingested data -- confirm.
# milvus_client.drop_collection()

# milvus_client.get_collection()

# Sanity check: report how many records the collection currently holds.
milvus_client.count_records()

Number of records in the collection: 34886


In [None]:

# Retrieval-augmented summary: embed the query, fetch the matching plots from
# Milvus, then ask the generation model to summarize the retrieved text.
plot_query = "Aliens seize the capitol."

# get_batch_embeddings is given a list of texts, so the single query is wrapped.
embedding = embedding_model.get_batch_embeddings([plot_query])
search_results = milvus_client.search(embedding)

retrieved_plots = []
for result in search_results:
    for hit in result:
        plot_text = hit.entity.get("plot_text")
        # print returns for context
        print(hit.entity.get("title"))
        print(hit.entity.get("release_year"))
        print(plot_text)
        # Skip hits without a plot: concatenating None would raise TypeError.
        if plot_text:
            retrieved_plots.append(plot_text)

# Join with a blank line so consecutive plots do not run into each other
# (plain `+=` fused the last word of one plot with the first word of the next).
context_texts = "\n\n".join(retrieved_plots)

prompt = "Summarize this: " + context_texts
# NOTE(review): do_sample=True makes the response non-deterministic across
# runs; seed the backend's RNG if reproducible output is needed.
response = generation_model.generate_response(
    prompt,
    max_new_tokens=250,
    num_return_sequences=1,
    do_sample=True,
    top_p=0.95,
    top_k=50,
)

print(response)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Embedding type: <class 'list'>, First element type: <class 'float'>
X from Outer Space !The X from Outer Space
1967
the spaceship aab gamma is dispatched from japan to travel to mars to investigate reports of ufos in the area. when the gamma nears the red planet, it comes across a mysterious alien vessel that sprays the ship with spores. samples are taken back to earth where one of them begins to develop. the cosmic spore grows into a giant, lizard-like creature dubbed "guilala." the monster begins a reign of destruction through tokyo. it spits fireballs, feeds on nuclear fuel, turns into a flying, burning sphere and destroys any airplanes and tanks in its path. guilala is finally defeated by jets laden with bombs, which coat it in a substance called "guilalalium." it causes guilala to shrink down to its original spore form. the government promptly launches the spore back into space, where it will circle the sun in an endless orbit.
Invasion
1966
an alien spacecraft crash-lands on eart