In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
import torch

# import
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.document_loaders import TextLoader
from langchain.embeddings import HuggingFaceEmbeddings

from langchain import LLMChain
from langchain.chains.mapreduce import MapReduceChain
from langchain.prompts import PromptTemplate
from langchain.llms.base import LLM
from transformers import AutoTokenizer, AutoModel, AutoConfig
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
from torch.mps import empty_cache
import torch
from langchain.chains import RetrievalQA

torch.manual_seed(1234)

from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from sentence_transformers import SentenceTransformer, util
from typing import Any, List, Optional
from pydantic import BaseModel


In [3]:
class QwenRunnable(BaseModel):
    """Pydantic container pairing a chat-capable model with its tokenizer.

    The method names mirror the ``_call`` / ``_llm_type`` hooks of a LangChain
    custom LLM, but this class derives from ``pydantic.BaseModel`` only.
    """

    # Model object; must expose ``chat(tokenizer, query=..., history=...)``
    # returning a ``(response, history)`` pair.
    model: Any
    # Tokenizer forwarded unchanged to ``model.chat``.
    tokenizer: Any
    # Device label stored with the model; no method of this class reads it.
    device: str = "cuda:1"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Run a single-turn chat and return the model's reply.

        Args:
            prompt: Text sent to the model as the query. No prior history is
                supplied (``history=None``).
            stop: Optional stop sequences. When given, the reply is cut off
                just before the earliest occurrence of any non-empty stop
                sequence. Previously this argument was accepted but ignored.

        Returns:
            The reply text, truncated at the first stop sequence if one occurs.
        """
        response, _ = self.model.chat(self.tokenizer, query=prompt, history=None)
        if stop:
            # Empty stop strings are skipped: ``str.find("")`` is always 0 and
            # would otherwise blank out every reply.
            hits = [response.find(marker) for marker in stop if marker]
            hits = [position for position in hits if position != -1]
            if hits:
                response = response[: min(hits)]
        return response

    @property
    def _llm_type(self) -> str:
        """Short identifier for this model family."""
        return "qwen"

class Qwen:
    """Lazy-loading wrapper around a Qwen chat checkpoint.

    Construction only records the configuration; call ``load_model()`` to read
    the tokenizer and weights before generating anything.
    """

    def __init__(self, model_path: str, device: str = "cuda:1"):
        """Store the checkpoint location and target device.

        Args:
            model_path: Path or identifier passed to ``from_pretrained``.
            device: Device string the model is moved to in ``load_model``.
        """
        self.model_path = model_path
        self.device = device
        # All three are populated by load_model().
        self.tokenizer = None
        self.model = None
        self.llm_runnable = None

    def load_model(self):
        """Load tokenizer, model and generation config, then build the runnable.

        The model is moved to ``self.device`` and switched to eval mode.
        ``trust_remote_code=True`` is passed to every ``from_pretrained`` call,
        so code shipped with the checkpoint is executed.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(self.model_path, trust_remote_code=True)
        self.model.to(self.device)
        self.model.eval()
        self.model.generation_config = GenerationConfig.from_pretrained(self.model_path, trust_remote_code=True)
        self.llm_runnable = QwenRunnable(model=self.model, tokenizer=self.tokenizer, device=self.device)

    def generate_response(self, prompt: str, history: Optional[list] = None):
        """Send ``prompt`` to the model's ``chat`` method and return the reply.

        Args:
            prompt: Query text for the model.
            history: Optional conversation history forwarded as-is to
                ``model.chat``. The history returned by ``chat`` is discarded.

        Returns:
            The first element of the ``(response, history)`` pair from ``chat``.

        Raises:
            RuntimeError: If ``load_model()`` has not been called yet, so no
                model is available.
        """
        if self.model is None:
            # Fail with an actionable message rather than an AttributeError
            # on None.
            raise RuntimeError(
                "Model is not loaded; call load_model() before generate_response()."
            )
        response, _ = self.model.chat(self.tokenizer, query=prompt, history=history)
        return response

class TopicGPTWithQwen(Qwen):
    """Prompt-based topic modelling pipeline built on the ``Qwen`` wrapper.

    The four stages are: generate candidate topics, drop near-duplicate topics,
    assign a topic to every document, and re-prompt for assignments that look
    invalid.
    """

    def generate_topics(self, documents, example_topics):
        """Ask the model for one topic per document, seeded with examples.

        Args:
            documents: Iterable of document texts.
            example_topics: List of seed topics; it is copied, not mutated.

        Returns:
            A new list: the seed topics followed by one model response per
            document, in document order.
        """
        # The same list object is passed to every call. Whether it accumulates
        # turns depends on whether ``model.chat`` mutates its ``history``
        # argument in place - TODO confirm against the checkpoint's chat code.
        history = []
        topics = example_topics.copy()
        for doc in documents:
            # NOTE(review): the prompt always lists the seed ``example_topics``,
            # not the growing ``topics`` list - confirm this is intended.
            prompt = f"Document: {doc}\nExample Topics: {example_topics}\nGenerate a new topic if the document doesn't fit existing topics."
            response = self.generate_response(prompt, history)
            topics.append(response)
        return topics

    def refine_topics(
        self,
        topics,
        similarity_threshold: float = 0.5,
        embedding_model_path: str = '/data1/dxw_data/llm/paraphrase-multilingual-MiniLM-L12-v2',
    ):
        """Remove topics that are near-duplicates of a later topic.

        A topic is kept only if no topic after it in the list has cosine
        similarity ``>= similarity_threshold`` with it. The last member of each
        group of similar topics therefore survives, and relative order is kept.

        Args:
            topics: Sequence of topic strings.
            similarity_threshold: Cosine similarity at or above which two
                topics count as duplicates. Defaults to the previously
                hard-coded 0.5.
            embedding_model_path: Path or name passed to
                ``SentenceTransformer``. Defaults to the previously hard-coded
                local checkpoint.

        Returns:
            List of the retained topic strings. Empty input returns ``[]``
            without loading the embedding model.
        """
        if not topics:
            return []

        embedder = SentenceTransformer(embedding_model_path)
        topic_embeddings = embedder.encode(topics, convert_to_tensor=True)
        # One n x n similarity matrix instead of O(n^2) separate cos_sim calls.
        similarity = util.cos_sim(topic_embeddings, topic_embeddings)

        refined_topics = []
        for i, topic in enumerate(topics):
            if topic in refined_topics:
                continue
            # Only later topics (columns i+1..) can suppress topic i. The slice
            # is empty for the final topic, so it is always retained.
            later_matches = similarity[i, i + 1:] >= similarity_threshold
            if not bool(later_matches.any()):
                refined_topics.append(topic)
        return refined_topics

    def assign_topics(self, documents, topics):
        """Prompt the model to pick a topic (with a quote) for each document.

        Args:
            documents: Iterable of document texts. They are used as dict keys,
                so they must be hashable and repeated documents collapse into
                one entry.
            topics: Topic list interpolated into the prompt.

        Returns:
            Dict mapping each document to the model's raw response.
        """
        # Shared across calls; see the note in generate_topics about whether
        # ``model.chat`` mutates it - TODO confirm.
        history = []
        assignments = {}
        for doc in documents:
            prompt = f"Document: {doc}\nTopics: {topics}\nAssign the most relevant topic to the document and provide a quote."
            response = self.generate_response(prompt, history)
            assignments[doc] = response
        return assignments

    def self_correct(self, assignments):
        """Re-prompt for assignments whose text contains "None" or "Error".

        The check is a case-sensitive substring match on the assignment text.

        Args:
            assignments: Dict of document -> assignment text, as produced by
                ``assign_topics``.

        Returns:
            A new dict with the same keys. Flagged entries hold the model's new
            response; all others are copied through unchanged.
        """
        history = []
        corrected_assignments = {}
        for doc, assignment in assignments.items():
            if "None" in assignment or "Error" in assignment:
                prompt = f"Document: {doc}\nError: {assignment}\nPlease reassign a valid topic."
                response = self.generate_response(prompt, history)
                corrected_assignments[doc] = response
            else:
                corrected_assignments[doc] = assignment
        return corrected_assignments
    
# Filesystem location of the Qwen-VL-Chat checkpoint directory
model_path = "/data1/dxw_data/llm/Qwen-VL-Chat"
# Device the model is moved to (e.g., 'cuda:0', 'cuda:1')
device = 'cuda:1'

# Build the topic-modelling wrapper, then load the tokenizer and weights
qwen_model = TopicGPTWithQwen(model_path=model_path, device=device)
qwen_model.load_model()

2024-06-15 10:04:19.169209: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-06-15 10:04:19.291374: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-06-15 10:04:19.878677: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64
2024-06-15 10:04:19.878751: W tensorflow/compiler/xla/stream_exec

Loading checkpoint shards:   0%|          | 0/10 [00:00<?, ?it/s]

In [4]:
# Sample data: a small corpus to run the pipeline on
documents: List[str] = [
    "The stock market saw a significant increase in value due to positive economic policies.",
    "New agricultural techniques have improved crop yields significantly.",
    "Tech companies are investing heavily in artificial intelligence research.",
    "Economic growth is expected to continue with new trade agreements.",
    "Farmers are adopting new technologies to boost production.",
]

# Seed topics in "Label: description" form, shown to the model as examples
example_topics: List[str] = [
    "Economy: Mentions policies, growth, and financial markets.",
    "Agriculture: Discusses farming techniques, crop yields, and agricultural policies.",
]

In [None]:
# ------------------ TopicGPT modeling that follows the paper more closely

In [5]:
# Step 1 - generate topics from the documents, seeded with the example topics
generated_topics = qwen_model.generate_topics(documents, example_topics)
print("Generated Topics:", generated_topics, sep="\n")

# Step 2 - refine topics by dropping near-duplicates
refined_topics = qwen_model.refine_topics(generated_topics)
print("Refined Topics:", refined_topics, sep="\n")

# Step 3 - assign one of the refined topics to every document
assignments = qwen_model.assign_topics(documents, refined_topics)
print("Topic Assignments:")
for doc, assignment in assignments.items():
    print(f"Document: {doc}\nAssignment: {assignment}\n")

# Step 4 - self-correction: re-prompt for assignments that look invalid
corrected_assignments = qwen_model.self_correct(assignments)
print("Corrected Topic Assignments:")
for doc, assignment in corrected_assignments.items():
    print(f"Document: {doc}\nCorrected Assignment: {assignment}\n")

Generated Topics:
['Economy: Mentions policies, growth, and financial markets.', 'Agriculture: Discusses farming techniques, crop yields, and agricultural policies.', "['Investment: Analyzes the impact of positive economic policies on the stock market and financial markets as a whole']", "['Agriculture: Examines the benefits of new agricultural techniques on crop yields and their potential impact on the economy and financial markets']", "['Tech Industry: Analyzes the potential impact of increased investment in artificial intelligence research on the tech industry and the economy as a whole']", "['Trade: Examines the potential benefits and risks of new trade agreements on economic growth and the financial markets']", "['Agriculture: Analyzes the impact of new technologies on crop yields and the potential benefits for farmers and the economy as a whole']"]
Refined Topics:
["['Tech Industry: Analyzes the potential impact of increased investment in artificial intelligence research on the t